diff --git a/.gitignore b/.gitignore index 903297db9690..e4c44d0590d5 100644 --- a/.gitignore +++ b/.gitignore @@ -62,6 +62,8 @@ project/plugins/src_managed/ project/plugins/target/ python/lib/pyspark.zip python/deps +python/test_coverage/coverage_data +python/test_coverage/htmlcov python/pyspark/python reports/ scalastyle-on-compile.generated.xml @@ -74,6 +76,7 @@ streaming-tests.log target/ unit-tests.log work/ +docs/.jekyll-metadata # For Hive TempStatsStore/ diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 6d46c3190626..855eb5bf77f1 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 2.3.0 +Version: 2.4.0 Title: R Frontend for Apache Spark Description: Provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 3219c6f0cc47..8cd00352d195 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -151,6 +151,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", + "repartitionByRange", "rollup", "sample", "sample_frac", @@ -179,6 +180,7 @@ exportMethods("arrange", "with", "withColumn", "withColumnRenamed", + "withWatermark", "write.df", "write.jdbc", "write.json", @@ -199,6 +201,9 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_max", + "array_min", + "array_position", "asc", "ascii", "asin", @@ -243,6 +248,7 @@ exportMethods("%<=>%", "decode", "dense_rank", "desc", + "element_at", "encode", "endsWith", "exp", @@ -252,6 +258,7 @@ exportMethods("%<=>%", "expr", "factorial", "first", + "flatten", "floor", "format_number", "format_string", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fe238f6dd4eb..a1c9495b0795 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -36,7 +36,6 @@ setOldClass("structType") #' @slot sdf A Java object reference to the backing Scala DataFrame #' @seealso \link{createDataFrame}, \link{read.json}, \link{table} #' @seealso \url{https://spark.apache.org/docs/latest/sparkr.html#sparkr-dataframes} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -77,7 +76,6 @@ setWriteMode <- function(write, mode) { write } -#' @export #' @param sdf A Java object reference to the backing Scala DataFrame #' @param isCached TRUE if the SparkDataFrame is cached #' @noRd @@ -97,7 +95,6 @@ dataFrame <- function(sdf, isCached = FALSE) { #' @rdname printSchema #' @name printSchema #' @aliases printSchema,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -123,7 +120,6 @@ setMethod("printSchema", #' @rdname schema #' @name schema #' @aliases schema,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -146,7 +142,6 @@ setMethod("schema", #' @aliases explain,SparkDataFrame-method #' @rdname explain #' @name explain -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -178,7 +173,6 @@ setMethod("explain", #' @rdname isLocal #' @name isLocal #' @aliases isLocal,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -209,7 +203,6 @@ setMethod("isLocal", #' @aliases showDF,SparkDataFrame-method #' @rdname showDF #' @name showDF -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -241,7 +234,6 @@ setMethod("showDF", #' @rdname show #' @aliases show,SparkDataFrame-method #' @name show -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -269,7 +261,6 @@ setMethod("show", "SparkDataFrame", #' @rdname dtypes #' @name dtypes #' 
@aliases dtypes,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -296,7 +287,6 @@ setMethod("dtypes", #' @rdname columns #' @name columns #' @aliases columns,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -388,7 +378,6 @@ setMethod("colnames<-", #' @aliases coltypes,SparkDataFrame-method #' @name coltypes #' @family SparkDataFrame functions -#' @export #' @examples #'\dontrun{ #' irisDF <- createDataFrame(iris) @@ -445,7 +434,6 @@ setMethod("coltypes", #' @rdname coltypes #' @name coltypes<- #' @aliases coltypes<-,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -494,7 +482,6 @@ setMethod("coltypes<-", #' @rdname createOrReplaceTempView #' @name createOrReplaceTempView #' @aliases createOrReplaceTempView,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -521,7 +508,6 @@ setMethod("createOrReplaceTempView", #' @rdname registerTempTable-deprecated #' @name registerTempTable #' @aliases registerTempTable,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -552,7 +538,6 @@ setMethod("registerTempTable", #' @rdname insertInto #' @name insertInto #' @aliases insertInto,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -580,7 +565,6 @@ setMethod("insertInto", #' @aliases cache,SparkDataFrame-method #' @rdname cache #' @name cache -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -611,7 +595,6 @@ setMethod("cache", #' @rdname persist #' @name persist #' @aliases persist,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -641,7 +624,6 @@ setMethod("persist", #' @rdname unpersist #' @aliases unpersist,SparkDataFrame-method #' @name unpersist -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -669,7 +651,6 @@ setMethod("unpersist", #' @rdname storageLevel #' @aliases storageLevel,SparkDataFrame-method #' @name storageLevel -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -706,8 +687,7 @@ setMethod("storageLevel", #' @rdname coalesce #' @name coalesce #' @aliases coalesce,SparkDataFrame-method -#' @seealso \link{repartition} -#' @export +#' @seealso \link{repartition}, \link{repartitionByRange} #' @examples #'\dontrun{ #' sparkR.session() @@ -743,8 +723,7 @@ setMethod("coalesce", #' @rdname repartition #' @name repartition #' @aliases repartition,SparkDataFrame-method -#' @seealso \link{coalesce} -#' @export +#' @seealso \link{coalesce}, \link{repartitionByRange} #' @examples #'\dontrun{ #' sparkR.session() @@ -780,6 +759,67 @@ setMethod("repartition", dataFrame(sdf) }) + +#' Repartition by range +#' +#' The following options for repartition by range are possible: +#' \itemize{ +#' \item{1.} {Return a new SparkDataFrame range partitioned by +#' the given columns into \code{numPartitions}.} +#' \item{2.} {Return a new SparkDataFrame range partitioned by the given column(s), +#' using \code{spark.sql.shuffle.partitions} as number of partitions.} +#'} +#' +#' @param x a SparkDataFrame. +#' @param numPartitions the number of partitions to use. +#' @param col the column by which the range partitioning will be performed. +#' @param ... additional column(s) to be used in the range partitioning. 
+#' +#' @family SparkDataFrame functions +#' @rdname repartitionByRange +#' @name repartitionByRange +#' @aliases repartitionByRange,SparkDataFrame-method +#' @seealso \link{repartition}, \link{coalesce} +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' newDF <- repartitionByRange(df, col = df$col1, df$col2) +#' newDF <- repartitionByRange(df, 3L, col = df$col1, df$col2) +#'} +#' @note repartitionByRange since 2.4.0 +setMethod("repartitionByRange", + signature(x = "SparkDataFrame"), + function(x, numPartitions = NULL, col = NULL, ...) { + if (!is.null(numPartitions) && !is.null(col)) { + # number of partitions and columns both are specified + if (is.numeric(numPartitions) && class(col) == "Column") { + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartitionByRange", numToInt(numPartitions), jcol) + } else { + stop(paste("numPartitions and col must be numeric and Column; however, got", + class(numPartitions), "and", class(col))) + } + } else if (!is.null(col)) { + # only columns are specified + if (class(col) == "Column") { + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartitionByRange", jcol) + } else { + stop(paste("col must be Column; however, got", class(col))) + } + } else if (!is.null(numPartitions)) { + # only numPartitions is specified + stop("At least one partition-by column must be specified.") + } else { + stop("Please, specify a column(s) or the number of partitions with a column(s)") + } + dataFrame(sdf) + }) + #' toJSON #' #' Converts a SparkDataFrame into a SparkDataFrame of JSON string. @@ -793,7 +833,6 @@ setMethod("repartition", #' @rdname toJSON #' @name toJSON #' @aliases toJSON,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -826,7 +865,6 @@ setMethod("toJSON", #' @rdname write.json #' @name write.json #' @aliases write.json,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -858,7 +896,6 @@ setMethod("write.json", #' @aliases write.orc,SparkDataFrame,character-method #' @rdname write.orc #' @name write.orc -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -890,7 +927,6 @@ setMethod("write.orc", #' @rdname write.parquet #' @name write.parquet #' @aliases write.parquet,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -911,7 +947,6 @@ setMethod("write.parquet", #' @rdname write.parquet #' @name saveAsParquetFile #' @aliases saveAsParquetFile,SparkDataFrame,character-method -#' @export #' @note saveAsParquetFile since 1.4.0 setMethod("saveAsParquetFile", signature(x = "SparkDataFrame", path = "character"), @@ -936,7 +971,6 @@ setMethod("saveAsParquetFile", #' @aliases write.text,SparkDataFrame,character-method #' @rdname write.text #' @name write.text -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -963,7 +997,6 @@ setMethod("write.text", #' @aliases distinct,SparkDataFrame-method #' @rdname distinct #' @name distinct -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1004,7 +1037,6 @@ setMethod("unique", #' @aliases sample,SparkDataFrame-method #' @rdname sample #' @name sample -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1061,7 +1093,6 @@ setMethod("sample_frac", #' @rdname nrow #' @name nrow #' @aliases count,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1094,7 +1125,6 @@ setMethod("nrow", #' @rdname ncol 
#' @name ncol #' @aliases ncol,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1118,7 +1148,6 @@ setMethod("ncol", #' @rdname dim #' @aliases dim,SparkDataFrame-method #' @name dim -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1144,7 +1173,6 @@ setMethod("dim", #' @rdname collect #' @aliases collect,SparkDataFrame-method #' @name collect -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1229,7 +1257,6 @@ setMethod("collect", #' @rdname limit #' @name limit #' @aliases limit,SparkDataFrame,numeric-method -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -1253,7 +1280,6 @@ setMethod("limit", #' @rdname take #' @name take #' @aliases take,SparkDataFrame,numeric-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1282,7 +1308,6 @@ setMethod("take", #' @aliases head,SparkDataFrame-method #' @rdname head #' @name head -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1307,7 +1332,6 @@ setMethod("head", #' @aliases first,SparkDataFrame-method #' @rdname first #' @name first -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1359,7 +1383,6 @@ setMethod("toRDD", #' @aliases groupBy,SparkDataFrame-method #' @rdname groupBy #' @name groupBy -#' @export #' @examples #' \dontrun{ #' # Compute the average for all numeric columns grouped by department. @@ -1401,7 +1424,6 @@ setMethod("group_by", #' @aliases agg,SparkDataFrame-method #' @rdname summarize #' @name agg -#' @export #' @note agg since 1.4.0 setMethod("agg", signature(x = "SparkDataFrame"), @@ -1460,7 +1482,6 @@ setClassUnion("characterOrstructType", c("character", "structType")) #' @aliases dapply,SparkDataFrame,function,characterOrstructType-method #' @name dapply #' @seealso \link{dapplyCollect} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -1519,7 +1540,6 @@ setMethod("dapply", #' @aliases dapplyCollect,SparkDataFrame,function-method #' @name dapplyCollect #' @seealso \link{dapply} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -1576,7 +1596,6 @@ setMethod("dapplyCollect", #' @rdname gapply #' @name gapply #' @seealso \link{gapplyCollect} -#' @export #' @examples #' #' \dontrun{ @@ -1673,7 +1692,6 @@ setMethod("gapply", #' @rdname gapplyCollect #' @name gapplyCollect #' @seealso \link{gapply} -#' @export #' @examples #' #' \dontrun{ @@ -1947,7 +1965,6 @@ setMethod("[", signature(x = "SparkDataFrame"), #' @param ... currently not used. #' @return A new SparkDataFrame containing only the rows that meet the condition with selected #' columns. -#' @export #' @family SparkDataFrame functions #' @aliases subset,SparkDataFrame-method #' @seealso \link{withColumn} @@ -1992,7 +2009,6 @@ setMethod("subset", signature(x = "SparkDataFrame"), #' If more than one column is assigned in \code{col}, \code{...} #' should be left empty. #' @return A new SparkDataFrame with selected columns. 
-#' @export #' @family SparkDataFrame functions #' @rdname select #' @aliases select,SparkDataFrame,character-method @@ -2024,7 +2040,6 @@ setMethod("select", signature(x = "SparkDataFrame", col = "character"), }) #' @rdname select -#' @export #' @aliases select,SparkDataFrame,Column-method #' @note select(SparkDataFrame, Column) since 1.4.0 setMethod("select", signature(x = "SparkDataFrame", col = "Column"), @@ -2037,7 +2052,6 @@ setMethod("select", signature(x = "SparkDataFrame", col = "Column"), }) #' @rdname select -#' @export #' @aliases select,SparkDataFrame,list-method #' @note select(SparkDataFrame, list) since 1.4.0 setMethod("select", @@ -2066,7 +2080,6 @@ setMethod("select", #' @aliases selectExpr,SparkDataFrame,character-method #' @rdname selectExpr #' @name selectExpr -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2090,14 +2103,14 @@ setMethod("selectExpr", #' #' @param x a SparkDataFrame. #' @param colName a column name. -#' @param col a Column expression, or an atomic vector in the length of 1 as literal value. +#' @param col a Column expression (which must refer only to this SparkDataFrame), or an atomic +#' vector in the length of 1 as literal value. #' @return A SparkDataFrame with the new column added or the existing column replaced. #' @family SparkDataFrame functions #' @aliases withColumn,SparkDataFrame,character-method #' @rdname withColumn #' @name withColumn #' @seealso \link{rename} \link{mutate} \link{subset} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2136,7 +2149,6 @@ setMethod("withColumn", #' @rdname mutate #' @name mutate #' @seealso \link{rename} \link{withColumn} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2207,7 +2219,6 @@ setMethod("mutate", }) #' @param _data a SparkDataFrame. 
-#' @export #' @rdname mutate #' @aliases transform,SparkDataFrame-method #' @name transform @@ -2231,7 +2242,6 @@ setMethod("transform", #' @name withColumnRenamed #' @aliases withColumnRenamed,SparkDataFrame,character,character-method #' @seealso \link{mutate} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2257,7 +2267,6 @@ setMethod("withColumnRenamed", #' @rdname rename #' @name rename #' @aliases rename,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2303,7 +2312,6 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @aliases arrange,SparkDataFrame,Column-method #' @rdname arrange #' @name arrange -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2334,7 +2342,6 @@ setMethod("arrange", #' @rdname arrange #' @name arrange #' @aliases arrange,SparkDataFrame,character-method -#' @export #' @note arrange(SparkDataFrame, character) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "character"), @@ -2367,7 +2374,6 @@ setMethod("arrange", #' @rdname arrange #' @aliases orderBy,SparkDataFrame,characterOrColumn-method -#' @export #' @note orderBy(SparkDataFrame, characterOrColumn) since 1.4.0 setMethod("orderBy", signature(x = "SparkDataFrame", col = "characterOrColumn"), @@ -2388,7 +2394,6 @@ setMethod("orderBy", #' @rdname filter #' @name filter #' @family subsetting functions -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2431,7 +2436,6 @@ setMethod("where", #' @aliases dropDuplicates,SparkDataFrame-method #' @rdname dropDuplicates #' @name dropDuplicates -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2480,7 +2484,6 @@ setMethod("dropDuplicates", #' @rdname join #' @name join #' @seealso \link{merge} \link{crossJoin} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2532,7 +2535,6 @@ setMethod("join", #' @rdname crossJoin #' @name crossJoin #' @seealso \link{merge} \link{join} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2580,7 +2582,6 @@ setMethod("crossJoin", #' @aliases merge,SparkDataFrame,SparkDataFrame-method #' @rdname merge #' @seealso \link{join} \link{crossJoin} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2720,7 +2721,6 @@ genAliasesForIntersectedCols <- function(x, intersectedColNames, suffix) { #' @name union #' @aliases union,SparkDataFrame,SparkDataFrame-method #' @seealso \link{rbind} \link{unionByName} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2741,7 +2741,6 @@ setMethod("union", #' @rdname union #' @name unionAll #' @aliases unionAll,SparkDataFrame,SparkDataFrame-method -#' @export #' @note unionAll since 1.4.0 setMethod("unionAll", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2768,7 +2767,6 @@ setMethod("unionAll", #' @name unionByName #' @aliases unionByName,SparkDataFrame,SparkDataFrame-method #' @seealso \link{rbind} \link{union} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2801,7 +2799,6 @@ setMethod("unionByName", #' @rdname rbind #' @name rbind #' @seealso \link{union} \link{unionByName} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2834,7 +2831,6 @@ setMethod("rbind", #' @aliases intersect,SparkDataFrame,SparkDataFrame-method #' @rdname intersect #' @name intersect -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2853,7 +2849,7 @@ setMethod("intersect", #' except #' #' Return a new SparkDataFrame containing rows in this SparkDataFrame -#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT} in SQL. 
+#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT DISTINCT} in SQL. #' #' @param x a SparkDataFrame. #' @param y a SparkDataFrame. @@ -2862,7 +2858,6 @@ setMethod("intersect", #' @aliases except,SparkDataFrame,SparkDataFrame-method #' @rdname except #' @name except -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2871,7 +2866,6 @@ setMethod("intersect", #' exceptDF <- except(df, df2) #' } #' @rdname except -#' @export #' @note except since 1.4.0 setMethod("except", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2908,7 +2902,6 @@ setMethod("except", #' @aliases write.df,SparkDataFrame-method #' @rdname write.df #' @name write.df -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2943,7 +2936,6 @@ setMethod("write.df", #' @rdname write.df #' @name saveDF #' @aliases saveDF,SparkDataFrame,character-method -#' @export #' @note saveDF since 1.4.0 setMethod("saveDF", signature(df = "SparkDataFrame", path = "character"), @@ -2977,7 +2969,6 @@ setMethod("saveDF", #' @aliases saveAsTable,SparkDataFrame,character-method #' @rdname saveAsTable #' @name saveAsTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3014,7 +3005,6 @@ setMethod("saveAsTable", #' @aliases describe,SparkDataFrame,character-method describe,SparkDataFrame,ANY-method #' @rdname describe #' @name describe -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3054,10 +3044,10 @@ setMethod("describe", #' \item stddev #' \item min #' \item max -#' \item arbitrary approximate percentiles specified as a percentage (eg, "75%") +#' \item arbitrary approximate percentiles specified as a percentage (eg, "75\%") #' } #' If no statistics are given, this function computes count, mean, stddev, min, -#' approximate quartiles (percentiles at 25%, 50%, and 75%), and max. +#' approximate quartiles (percentiles at 25\%, 50\%, and 75\%), and max. #' This function is meant for exploratory data analysis, as we make no guarantee about the #' backward compatibility of the schema of the resulting Dataset. If you want to #' programmatically compute summary statistics, use the \code{agg} function instead. 
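The hunk above tightens the summary() documentation (the default statistics and the escaped percentile syntax). As a quick illustration of the documented behavior, here is a minimal SparkR sketch, assuming an active sparkR.session(); the faithful dataset and column name are only illustrative:

df <- createDataFrame(faithful)
# default statistics: count, mean, stddev, min, approximate 25%/50%/75% percentiles, max
head(summary(df))
# request specific statistics, including an arbitrary approximate percentile
head(summary(df, "count", "min", "75%", "max"))
# describe() computes count, mean, stddev, min and max for the named column(s)
head(describe(df, "eruptions"))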
@@ -3070,7 +3060,6 @@ setMethod("describe", #' @rdname summary #' @name summary #' @aliases summary,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3116,7 +3105,6 @@ setMethod("summary", #' @rdname nafunctions #' @aliases dropna,SparkDataFrame-method #' @name dropna -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3147,7 +3135,6 @@ setMethod("dropna", #' @rdname nafunctions #' @name na.omit #' @aliases na.omit,SparkDataFrame-method -#' @export #' @note na.omit since 1.5.0 setMethod("na.omit", signature(object = "SparkDataFrame"), @@ -3167,7 +3154,6 @@ setMethod("na.omit", #' @rdname nafunctions #' @name fillna #' @aliases fillna,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3398,7 +3384,6 @@ setMethod("str", #' @rdname drop #' @name drop #' @aliases drop,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3426,7 +3411,6 @@ setMethod("drop", #' @name drop #' @rdname drop #' @aliases drop,ANY-method -#' @export setMethod("drop", signature(x = "ANY"), function(x) { @@ -3445,7 +3429,6 @@ setMethod("drop", #' @rdname histogram #' @aliases histogram,SparkDataFrame,characterOrColumn-method #' @family SparkDataFrame functions -#' @export #' @examples #' \dontrun{ #' @@ -3581,7 +3564,6 @@ setMethod("histogram", #' @rdname write.jdbc #' @name write.jdbc #' @aliases write.jdbc,SparkDataFrame,character,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3610,7 +3592,6 @@ setMethod("write.jdbc", #' @aliases randomSplit,SparkDataFrame,numeric-method #' @rdname randomSplit #' @name randomSplit -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3644,7 +3625,6 @@ setMethod("randomSplit", #' @aliases getNumPartitions,SparkDataFrame-method #' @rdname getNumPartitions #' @name getNumPartitions -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3661,7 +3641,8 @@ setMethod("getNumPartitions", #' isStreaming #' #' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data -#' as it arrives. +#' as it arrives. A dataset that reads data from a streaming source must be executed as a +#' \code{StreamingQuery} using \code{write.stream}. #' #' @param x A SparkDataFrame #' @return TRUE if this SparkDataFrame is from a streaming source @@ -3670,7 +3651,6 @@ setMethod("getNumPartitions", #' @rdname isStreaming #' @name isStreaming #' @seealso \link{read.stream} \link{write.stream} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3707,14 +3687,23 @@ setMethod("isStreaming", #' @param df a streaming SparkDataFrame. #' @param source a name for external data source. #' @param outputMode one of 'append', 'complete', 'update'. -#' @param ... additional argument(s) passed to the method. +#' @param partitionBy a name or a list of names of columns to partition the output by on the file +#' system. If specified, the output is laid out on the file system similar to Hive's +#' partitioning scheme. +#' @param trigger.processingTime a processing time interval as a string, e.g. '5 seconds', +#' '1 minute'. This is a trigger that runs a query periodically based on the processing +#' time. If value is '0 seconds', the query will run as fast as possible, this is the +#' default. Only one trigger can be set. +#' @param trigger.once a logical, must be set to \code{TRUE}. This is a trigger that processes only +#' one batch of data in a streaming query then terminates the query. Only one trigger can be +#' set. +#' @param ... 
additional external data source specific named options. #' #' @family SparkDataFrame functions #' @seealso \link{read.stream} #' @aliases write.stream,SparkDataFrame-method #' @rdname write.stream #' @name write.stream -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3725,7 +3714,8 @@ setMethod("isStreaming", #' # console #' q <- write.stream(wordCounts, "console", outputMode = "complete") #' # text stream -#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp" +#' partitionBy = c("year", "month"), trigger.processingTime = "30 seconds") #' # memory stream #' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") #' head(sql("SELECT * from outs")) @@ -3737,7 +3727,8 @@ setMethod("isStreaming", #' @note experimental setMethod("write.stream", signature(df = "SparkDataFrame"), - function(df, source = NULL, outputMode = NULL, ...) { + function(df, source = NULL, outputMode = NULL, partitionBy = NULL, + trigger.processingTime = NULL, trigger.once = NULL, ...) { if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the data source specified ", "in 'spark.sql.sources.default' configuration by default.") @@ -3748,12 +3739,43 @@ setMethod("write.stream", if (is.null(source)) { source <- getDefaultSqlSource() } + cols <- NULL + if (!is.null(partitionBy)) { + if (!all(sapply(partitionBy, function(c) { is.character(c) }))) { + stop("All partitionBy column names should be characters.") + } + cols <- as.list(partitionBy) + } + jtrigger <- NULL + if (!is.null(trigger.processingTime) && !is.na(trigger.processingTime)) { + if (!is.null(trigger.once)) { + stop("Multiple triggers not allowed.") + } + interval <- as.character(trigger.processingTime) + if (nchar(interval) == 0) { + stop("Value for trigger.processingTime must be a non-empty string.") + } + jtrigger <- handledCallJStatic("org.apache.spark.sql.streaming.Trigger", + "ProcessingTime", + interval) + } else if (!is.null(trigger.once) && !is.na(trigger.once)) { + if (!is.logical(trigger.once) || !trigger.once) { + stop("Value for trigger.once must be TRUE.") + } + jtrigger <- callJStatic("org.apache.spark.sql.streaming.Trigger", "Once") + } options <- varargsToStrEnv(...) 
write <- handledCallJMethod(df@sdf, "writeStream") write <- callJMethod(write, "format", source) if (!is.null(outputMode)) { write <- callJMethod(write, "outputMode", outputMode) } + if (!is.null(cols)) { + write <- callJMethod(write, "partitionBy", cols) + } + if (!is.null(jtrigger)) { + write <- callJMethod(write, "trigger", jtrigger) + } write <- callJMethod(write, "options", options) ssq <- handledCallJMethod(write, "start") streamingQuery(ssq) @@ -3774,7 +3796,6 @@ setMethod("write.stream", #' @rdname checkpoint #' @name checkpoint #' @seealso \link{setCheckpointDir} -#' @export #' @examples #'\dontrun{ #' setCheckpointDir("/checkpoint") @@ -3802,7 +3823,6 @@ setMethod("checkpoint", #' @aliases localCheckpoint,SparkDataFrame-method #' @rdname localCheckpoint #' @name localCheckpoint -#' @export #' @examples #'\dontrun{ #' df <- localCheckpoint(df) @@ -3829,7 +3849,6 @@ setMethod("localCheckpoint", #' @aliases cube,SparkDataFrame-method #' @rdname cube #' @name cube -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -3864,7 +3883,6 @@ setMethod("cube", #' @aliases rollup,SparkDataFrame-method #' @rdname rollup #' @name rollup -#' @export #' @examples #'\dontrun{ #' df <- createDataFrame(mtcars) @@ -3897,7 +3915,6 @@ setMethod("rollup", #' @aliases hint,SparkDataFrame,character-method #' @rdname hint #' @name hint -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -3921,7 +3938,6 @@ setMethod("hint", #' @family SparkDataFrame functions #' @rdname alias #' @name alias -#' @export #' @examples #' \dontrun{ #' df <- alias(createDataFrame(mtcars), "mtcars") @@ -3952,7 +3968,6 @@ setMethod("alias", #' @family SparkDataFrame functions #' @rdname broadcast #' @name broadcast -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -3967,3 +3982,46 @@ setMethod("broadcast", sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) dataFrame(sdf) }) + +#' withWatermark +#' +#' Defines an event time watermark for this streaming SparkDataFrame. A watermark tracks a point in +#' time before which we assume no more late data is going to arrive. +#' +#' Spark will use this watermark for several purposes: +#' \itemize{ +#' \item To know when a given time window aggregation can be finalized and thus can be emitted +#' when using output modes that do not allow updates. +#' \item To minimize the amount of state that we need to keep for on-going aggregations. +#' } +#' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across +#' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost +#' of coordinating this value across partitions, the actual watermark used is only guaranteed +#' to be at least \code{delayThreshold} behind the actual event time. In some cases we may still +#' process records that arrive more than \code{delayThreshold} late. +#' +#' @param x a streaming SparkDataFrame +#' @param eventTime a string specifying the name of the Column that contains the event time of the +#' row. +#' @param delayThreshold a string specifying the minimum delay to wait to data to arrive late, +#' relative to the latest record that has been processed in the form of an +#' interval (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative. +#' @return a SparkDataFrame. 
+#' @aliases withWatermark,SparkDataFrame,character,character-method +#' @family SparkDataFrame functions +#' @rdname withWatermark +#' @name withWatermark +#' @examples +#' \dontrun{ +#' sparkR.session() +#' schema <- structType(structField("time", "timestamp"), structField("value", "double")) +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' df <- withWatermark(df, "time", "10 minutes") +#' } +#' @note withWatermark since 2.3.0 +setMethod("withWatermark", + signature(x = "SparkDataFrame", eventTime = "character", delayThreshold = "character"), + function(x, eventTime, delayThreshold) { + sdf <- callJMethod(x@sdf, "withWatermark", eventTime, delayThreshold) + dataFrame(sdf) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 3b7f71bbbffb..429dd5d56549 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -123,7 +123,6 @@ infer_type <- function(x) { #' @return a list of config values with keys as their names #' @rdname sparkR.conf #' @name sparkR.conf -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -163,7 +162,6 @@ sparkR.conf <- function(key, defaultValue) { #' @return a character string of the Spark version #' @rdname sparkR.version #' @name sparkR.version -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -191,7 +189,6 @@ getDefaultSqlSource <- function() { #' limited by length of the list or number of rows of the data.frame #' @return A SparkDataFrame. #' @rdname createDataFrame -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -294,7 +291,6 @@ createDataFrame <- function(x, ...) { #' @rdname createDataFrame #' @aliases createDataFrame -#' @export #' @method as.DataFrame default #' @note as.DataFrame since 1.6.0 as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) { @@ -304,7 +300,6 @@ as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPa #' @param ... additional argument(s). #' @rdname createDataFrame #' @aliases as.DataFrame -#' @export as.DataFrame <- function(data, ...) { dispatchFunc("as.DataFrame(data, schema = NULL)", data, ...) } @@ -342,7 +337,6 @@ setMethod("toDF", signature(x = "RDD"), #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.json -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -371,7 +365,6 @@ read.json <- function(x, ...) { #' @rdname read.json #' @name jsonFile -#' @export #' @method jsonFile default #' @note jsonFile since 1.4.0 jsonFile.default <- function(path) { @@ -423,7 +416,6 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.orc -#' @export #' @name read.orc #' @note read.orc since 2.0.0 read.orc <- function(path, ...) { @@ -444,7 +436,6 @@ read.orc <- function(path, ...) { #' @param path path of file to read. A vector of multiple paths is allowed. #' @return SparkDataFrame #' @rdname read.parquet -#' @export #' @name read.parquet #' @method read.parquet default #' @note read.parquet since 1.6.0 @@ -466,7 +457,6 @@ read.parquet <- function(x, ...) { #' @param ... argument(s) passed to the method. #' @rdname read.parquet #' @name parquetFile -#' @export #' @method parquetFile default #' @note parquetFile since 1.4.0 parquetFile.default <- function(...) { @@ -490,7 +480,6 @@ parquetFile <- function(x, ...) { #' @param ... 
additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.text -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -522,7 +511,6 @@ read.text <- function(x, ...) { #' @param sqlQuery A character vector containing the SQL query #' @return SparkDataFrame #' @rdname sql -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -556,7 +544,6 @@ sql <- function(x, ...) { #' @return SparkDataFrame #' @rdname tableToDF #' @name tableToDF -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -591,7 +578,6 @@ tableToDF <- function(tableName) { #' @rdname read.df #' @name read.df #' @seealso \link{read.json} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -681,7 +667,6 @@ loadDF <- function(x = NULL, ...) { #' @return SparkDataFrame #' @rdname read.jdbc #' @name read.jdbc -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -727,12 +712,13 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it +#' uses the default value, session local timezone. #' @return SparkDataFrame #' @rdname read.stream #' @name read.stream #' @seealso \link{write.stream} -#' @export #' @examples #'\dontrun{ #' sparkR.session() diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index debc7cbde55e..ee7f4adf726e 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -28,7 +28,6 @@ NULL #' @seealso \link{windowPartitionBy}, \link{windowOrderBy} #' #' @param sws A Java object reference to the backing Scala WindowSpec -#' @export #' @note WindowSpec since 2.0.0 setClass("WindowSpec", slots = list(sws = "jobj")) @@ -44,7 +43,6 @@ windowSpec <- function(sws) { } #' @rdname show -#' @export #' @note show(WindowSpec) since 2.0.0 setMethod("show", "WindowSpec", function(object) { @@ -63,7 +61,6 @@ setMethod("show", "WindowSpec", #' @name partitionBy #' @aliases partitionBy,WindowSpec-method #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' partitionBy(ws, "col1", "col2") @@ -97,7 +94,6 @@ setMethod("partitionBy", #' @aliases orderBy,WindowSpec,character-method #' @family windowspec_method #' @seealso See \link{arrange} for use in sorting a SparkDataFrame -#' @export #' @examples #' \dontrun{ #' orderBy(ws, "col1", "col2") @@ -113,7 +109,6 @@ setMethod("orderBy", #' @rdname orderBy #' @name orderBy #' @aliases orderBy,WindowSpec,Column-method -#' @export #' @note orderBy(WindowSpec, Column) since 2.0.0 setMethod("orderBy", signature(x = "WindowSpec", col = "Column"), @@ -142,7 +137,6 @@ setMethod("orderBy", #' @aliases rowsBetween,WindowSpec,numeric,numeric-method #' @name rowsBetween #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' rowsBetween(ws, 0, 3) @@ -174,7 +168,6 @@ setMethod("rowsBetween", #' @aliases rangeBetween,WindowSpec,numeric,numeric-method #' @name rangeBetween #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' rangeBetween(ws, 0, 3) @@ -202,7 +195,6 @@ setMethod("rangeBetween", #' @name over #' @aliases over,Column,WindowSpec-method #' @family colum_func -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) diff --git 
a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 398dffc4ab1b..282f8a685773 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -32,14 +32,12 @@ # @seealso broadcast # # @param id Id of the backing Spark broadcast variable -# @export setClass("Broadcast", slots = list(id = "character")) # @rdname broadcast-class # @param value Value of the broadcast variable # @param jBroadcastRef reference to the backing Java broadcast object # @param objName name of broadcasted object -# @export Broadcast <- function(id, value, jBroadcastRef, objName) { .broadcastValues[[id]] <- value .broadcastNames[[as.character(objName)]] <- jBroadcastRef @@ -73,7 +71,6 @@ setMethod("value", # @param bcastId The id of broadcast variable to set # @param value The value to be set -# @export setBroadcastValue <- function(bcastId, value) { bcastIdStr <- as.character(bcastId) .broadcastValues[[bcastIdStr]] <- value diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R index e59a7024333a..baf4d861fcf8 100644 --- a/R/pkg/R/catalog.R +++ b/R/pkg/R/catalog.R @@ -34,7 +34,6 @@ #' @return A SparkDataFrame. #' @rdname createExternalTable-deprecated #' @seealso \link{createTable} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -71,7 +70,6 @@ createExternalTable <- function(x, ...) { #' @return A SparkDataFrame. #' @rdname createTable #' @seealso \link{createExternalTable} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -110,7 +108,6 @@ createTable <- function(tableName, path = NULL, source = NULL, schema = NULL, .. #' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname cacheTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -140,7 +137,6 @@ cacheTable <- function(x, ...) { #' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname uncacheTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -167,7 +163,6 @@ uncacheTable <- function(x, ...) { #' Removes all cached tables from the in-memory cache. #' #' @rdname clearCache -#' @export #' @examples #' \dontrun{ #' clearCache() @@ -193,7 +188,6 @@ clearCache <- function() { #' @param tableName The name of the SparkSQL table to be dropped. #' @seealso \link{dropTempView} #' @rdname dropTempTable-deprecated -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -225,7 +219,6 @@ dropTempTable <- function(x, ...) { #' @return TRUE if the view is dropped successfully, FALSE otherwise. #' @rdname dropTempView #' @name dropTempView -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -251,7 +244,6 @@ dropTempView <- function(viewName) { #' @return a SparkDataFrame #' @rdname tables #' @seealso \link{listTables} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -276,7 +268,6 @@ tables <- function(x, ...) { #' @param databaseName (optional) name of the database #' @return a list of table names #' @rdname tableNames -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -304,7 +295,6 @@ tableNames <- function(x, ...) { #' @return name of the current default database. #' @rdname currentDatabase #' @name currentDatabase -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -324,7 +314,6 @@ currentDatabase <- function() { #' @param databaseName name of the database #' @rdname setCurrentDatabase #' @name setCurrentDatabase -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -347,7 +336,6 @@ setCurrentDatabase <- function(databaseName) { #' @return a SparkDataFrame of the list of databases. 
#' @rdname listDatabases #' @name listDatabases -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -370,7 +358,6 @@ listDatabases <- function() { #' @rdname listTables #' @name listTables #' @seealso \link{tables} -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -403,7 +390,6 @@ listTables <- function(databaseName = NULL) { #' @return a SparkDataFrame of the list of column descriptions. #' @rdname listColumns #' @name listColumns -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -433,7 +419,6 @@ listColumns <- function(tableName, databaseName = NULL) { #' @return a SparkDataFrame of the list of function descriptions. #' @rdname listFunctions #' @name listFunctions -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -463,7 +448,6 @@ listFunctions <- function(databaseName = NULL) { #' identifier is provided, it refers to a table in the current database. #' @rdname recoverPartitions #' @name recoverPartitions -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -490,7 +474,6 @@ recoverPartitions <- function(tableName) { #' identifier is provided, it refers to a table in the current database. #' @rdname refreshTable #' @name refreshTable -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -512,7 +495,6 @@ refreshTable <- function(tableName) { #' @param path the path of the data source. #' @rdname refreshByPath #' @name refreshByPath -#' @export #' @examples #' \dontrun{ #' sparkR.session() diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 3095adb918b6..7926a9a2467e 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -29,7 +29,6 @@ setOldClass("jobj") #' @rdname column #' #' @slot jc reference to JVM SparkDataFrame column -#' @export #' @note Column since 1.4.0 setClass("Column", slots = list(jc = "jobj")) @@ -56,7 +55,6 @@ setMethod("column", #' @rdname show #' @name show #' @aliases show,Column-method -#' @export #' @note show(Column) since 1.4.0 setMethod("show", "Column", function(object) { @@ -134,7 +132,6 @@ createMethods() #' @name alias #' @aliases alias,Column-method #' @family colum_func -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -164,12 +161,18 @@ setMethod("alias", #' @aliases substr,Column-method #' #' @param x a Column. -#' @param start starting position. +#' @param start starting position. It should be 1-base. #' @param stop ending position. +#' @examples +#' \dontrun{ +#' df <- createDataFrame(list(list(a="abcdef"))) +#' collect(select(df, substr(df$a, 1, 4))) # the result is `abcd`. +#' collect(select(df, substr(df$a, 2, 4))) # the result is `bcd`. +#' } #' @note substr since 1.4.0 setMethod("substr", signature(x = "Column"), function(x, start, stop) { - jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + jc <- callJMethod(x@jc, "substr", as.integer(start), as.integer(stop - start + 1)) column(jc) }) @@ -270,7 +273,6 @@ setMethod("cast", #' @name %in% #' @aliases %in%,Column-method #' @return A matched values as a result of comparing with given values. 
-#' @export #' @examples #' \dontrun{ #' filter(df, "age in (10, 30)") @@ -296,7 +298,6 @@ setMethod("%in%", #' @name otherwise #' @family colum_func #' @aliases otherwise,Column-method -#' @export #' @note otherwise since 1.5.0 setMethod("otherwise", signature(x = "Column", value = "ANY"), @@ -318,7 +319,6 @@ setMethod("otherwise", #' @rdname eq_null_safe #' @name %<=>% #' @aliases %<=>%,Column-method -#' @export #' @examples #' \dontrun{ #' df1 <- createDataFrame(data.frame( @@ -348,7 +348,6 @@ setMethod("%<=>%", #' @rdname not #' @name not #' @aliases !,Column-method -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame(x = c(-1, 0, 1))) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 443c2ff8f9ac..8ec727dd042b 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -308,7 +308,6 @@ setCheckpointDirSC <- function(sc, dirName) { #' @rdname spark.addFile #' @param path The path of the file to be added #' @param recursive Whether to add files recursively from the path. Default is FALSE. -#' @export #' @examples #'\dontrun{ #' spark.addFile("~/myfile") @@ -323,7 +322,6 @@ spark.addFile <- function(path, recursive = FALSE) { #' #' @rdname spark.getSparkFilesRootDirectory #' @return the root directory that contains files added through spark.addFile -#' @export #' @examples #'\dontrun{ #' spark.getSparkFilesRootDirectory() @@ -344,7 +342,6 @@ spark.getSparkFilesRootDirectory <- function() { # nolint #' @rdname spark.getSparkFiles #' @param fileName The name of the file added through spark.addFile #' @return the absolute path of a file added through spark.addFile. -#' @export #' @examples #'\dontrun{ #' spark.getSparkFiles("myfile") @@ -391,7 +388,6 @@ spark.getSparkFiles <- function(fileName) { #' @param list the list of elements #' @param func a function that takes one argument. #' @return a list of results (the exact type being determined by the function) -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -412,7 +408,6 @@ spark.lapply <- function(list, func) { #' #' @rdname setLogLevel #' @param level New log level -#' @export #' @examples #'\dontrun{ #' setLogLevel("ERROR") @@ -431,7 +426,6 @@ setLogLevel <- function(level) { #' @rdname setCheckpointDir #' @param directory Directory path to checkpoint to #' @seealso \link{checkpoint} -#' @export #' @examples #'\dontrun{ #' setCheckpointDir("/checkpoint") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 55365a41d774..0ec99d19e21e 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,6 +189,11 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param value A value to compute on. +#' \itemize{ +#' \item \code{array_contains}: a value to be checked if contained in the column. +#' \item \code{array_position}: a value to locate in the given array. +#' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same #' options as the JSON data source. 
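The new @param entries above document the collection functions added in this release (array_max, array_min, array_position, element_at). A minimal sketch of how they read on array and map columns, assuming an active sparkR.session(); the mtcars-based data mirrors the examples used elsewhere in this file:

df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
# largest and smallest element of each array
head(select(tmp, array_max(tmp$v1), array_min(tmp$v1)))
# 1-based position of the first occurrence of 21; 0 when the value is absent
head(select(tmp, array_position(tmp$v1, 21)))
# element_at looks up a map entry by key and returns NA when the key is missing
tmp2 <- mutate(df, v2 = create_map(df$model, df$cyl))
head(select(tmp2, element_at(tmp2$v2, "Valiant")))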
@@ -201,6 +206,9 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21))) +#' head(select(tmp, flatten(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -208,7 +216,8 @@ NULL #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) #' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3)))} +#' head(select(tmp3, map_values(tmp3$v3))) +#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))} NULL #' Window functions for Column operations @@ -244,7 +253,6 @@ NULL #' If the parameter is a Column, it is returned unchanged. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases lit lit,ANY-method #' @examples #' @@ -267,7 +275,6 @@ setMethod("lit", signature("ANY"), #' \code{abs}: Computes the absolute value. #' #' @rdname column_math_functions -#' @export #' @aliases abs abs,Column-method #' @note abs since 1.5.0 setMethod("abs", @@ -278,11 +285,10 @@ setMethod("abs", }) #' @details -#' \code{acos}: Computes the cosine inverse of the given value; the returned angle is in -#' the range 0.0 through pi. +#' \code{acos}: Returns the inverse cosine of the given value, +#' as if computed by \code{java.lang.Math.acos()} #' #' @rdname column_math_functions -#' @export #' @aliases acos acos,Column-method #' @note acos since 1.5.0 setMethod("acos", @@ -296,7 +302,6 @@ setMethod("acos", #' \code{approxCountDistinct}: Returns the approximate number of distinct items in a group. #' #' @rdname column_aggregate_functions -#' @export #' @aliases approxCountDistinct approxCountDistinct,Column-method #' @examples #' @@ -319,7 +324,6 @@ setMethod("approxCountDistinct", #' and returns the result as an int column. #' #' @rdname column_string_functions -#' @export #' @aliases ascii ascii,Column-method #' @examples #' @@ -334,11 +338,10 @@ setMethod("ascii", }) #' @details -#' \code{asin}: Computes the sine inverse of the given value; the returned angle is in -#' the range -pi/2 through pi/2. +#' \code{asin}: Returns the inverse sine of the given value, +#' as if computed by \code{java.lang.Math.asin()} #' #' @rdname column_math_functions -#' @export #' @aliases asin asin,Column-method #' @note asin since 1.5.0 setMethod("asin", @@ -349,11 +352,10 @@ setMethod("asin", }) #' @details -#' \code{atan}: Computes the tangent inverse of the given value; the returned angle is in the range -#' -pi/2 through pi/2. +#' \code{atan}: Returns the inverse tangent of the given value, +#' as if computed by \code{java.lang.Math.atan()} #' #' @rdname column_math_functions -#' @export #' @aliases atan atan,Column-method #' @note atan since 1.5.0 setMethod("atan", @@ -370,7 +372,6 @@ setMethod("atan", #' @rdname avg #' @name avg #' @family aggregate functions -#' @export #' @aliases avg,Column-method #' @examples \dontrun{avg(df$c)} #' @note avg since 1.4.0 @@ -386,7 +387,6 @@ setMethod("avg", #' a string column. This is the reverse of unbase64. #' #' @rdname column_string_functions -#' @export #' @aliases base64 base64,Column-method #' @examples #' @@ -410,7 +410,6 @@ setMethod("base64", #' of the given long column. For example, bin("12") returns "1100". 
#' #' @rdname column_math_functions -#' @export #' @aliases bin bin,Column-method #' @note bin since 1.5.0 setMethod("bin", @@ -424,7 +423,6 @@ setMethod("bin", #' \code{bitwiseNOT}: Computes bitwise NOT. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases bitwiseNOT bitwiseNOT,Column-method #' @examples #' @@ -442,7 +440,6 @@ setMethod("bitwiseNOT", #' \code{cbrt}: Computes the cube-root of the given value. #' #' @rdname column_math_functions -#' @export #' @aliases cbrt cbrt,Column-method #' @note cbrt since 1.4.0 setMethod("cbrt", @@ -456,7 +453,6 @@ setMethod("cbrt", #' \code{ceil}: Computes the ceiling of the given value. #' #' @rdname column_math_functions -#' @export #' @aliases ceil ceil,Column-method #' @note ceil since 1.5.0 setMethod("ceil", @@ -471,7 +467,6 @@ setMethod("ceil", #' #' @rdname column_math_functions #' @aliases ceiling ceiling,Column-method -#' @export #' @note ceiling since 1.5.0 setMethod("ceiling", signature(x = "Column"), @@ -483,7 +478,6 @@ setMethod("ceiling", #' \code{coalesce}: Returns the first column that is not NA, or NA if all inputs are. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases coalesce,Column-method #' @note coalesce(Column) since 2.1.1 setMethod("coalesce", @@ -514,7 +508,6 @@ col <- function(x) { #' @rdname column #' @name column #' @family non-aggregate functions -#' @export #' @aliases column,character-method #' @examples \dontrun{column("name")} #' @note column since 1.6.0 @@ -533,7 +526,6 @@ setMethod("column", #' @rdname corr #' @name corr #' @family aggregate functions -#' @export #' @aliases corr,Column-method #' @examples #' \dontrun{ @@ -557,7 +549,6 @@ setMethod("corr", signature(x = "Column"), #' @rdname cov #' @name cov #' @family aggregate functions -#' @export #' @aliases cov,characterOrColumn-method #' @examples #' \dontrun{ @@ -598,7 +589,6 @@ setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterO #' #' @rdname cov #' @name covar_pop -#' @export #' @aliases covar_pop,characterOrColumn,characterOrColumn-method #' @note covar_pop since 2.0.0 setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), @@ -613,11 +603,11 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr }) #' @details -#' \code{cos}: Computes the cosine of the given value. Units in radians. +#' \code{cos}: Returns the cosine of the given value, +#' as if computed by \code{java.lang.Math.cos()}. Units in radians. #' #' @rdname column_math_functions #' @aliases cos cos,Column-method -#' @export #' @note cos since 1.5.0 setMethod("cos", signature(x = "Column"), @@ -627,11 +617,11 @@ setMethod("cos", }) #' @details -#' \code{cosh}: Computes the hyperbolic cosine of the given value. +#' \code{cosh}: Returns the hyperbolic cosine of the given value, +#' as if computed by \code{java.lang.Math.cosh()}. 
#' #' @rdname column_math_functions #' @aliases cosh cosh,Column-method -#' @export #' @note cosh since 1.5.0 setMethod("cosh", signature(x = "Column"), @@ -649,7 +639,6 @@ setMethod("cosh", #' @name count #' @family aggregate functions #' @aliases count,Column-method -#' @export #' @examples \dontrun{count(df$c)} #' @note count since 1.4.0 setMethod("count", @@ -665,7 +654,6 @@ setMethod("count", #' #' @rdname column_misc_functions #' @aliases crc32 crc32,Column-method -#' @export #' @note crc32 since 1.5.0 setMethod("crc32", signature(x = "Column"), @@ -680,7 +668,6 @@ setMethod("crc32", #' #' @rdname column_misc_functions #' @aliases hash hash,Column-method -#' @export #' @note hash since 2.0.0 setMethod("hash", signature(x = "Column"), @@ -699,7 +686,6 @@ setMethod("hash", #' #' @rdname column_datetime_functions #' @aliases dayofmonth dayofmonth,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -721,7 +707,6 @@ setMethod("dayofmonth", #' #' @rdname column_datetime_functions #' @aliases dayofweek dayofweek,Column-method -#' @export #' @note dayofweek since 2.3.0 setMethod("dayofweek", signature(x = "Column"), @@ -736,7 +721,6 @@ setMethod("dayofweek", #' #' @rdname column_datetime_functions #' @aliases dayofyear dayofyear,Column-method -#' @export #' @note dayofyear since 1.5.0 setMethod("dayofyear", signature(x = "Column"), @@ -754,7 +738,6 @@ setMethod("dayofyear", #' #' @rdname column_string_functions #' @aliases decode decode,Column,character-method -#' @export #' @note decode since 1.6.0 setMethod("decode", signature(x = "Column", charset = "character"), @@ -769,7 +752,6 @@ setMethod("decode", #' #' @rdname column_string_functions #' @aliases encode encode,Column,character-method -#' @export #' @note encode since 1.6.0 setMethod("encode", signature(x = "Column", charset = "character"), @@ -783,7 +765,6 @@ setMethod("encode", #' #' @rdname column_math_functions #' @aliases exp exp,Column-method -#' @export #' @note exp since 1.5.0 setMethod("exp", signature(x = "Column"), @@ -797,7 +778,6 @@ setMethod("exp", #' #' @rdname column_math_functions #' @aliases expm1 expm1,Column-method -#' @export #' @note expm1 since 1.5.0 setMethod("expm1", signature(x = "Column"), @@ -811,7 +791,6 @@ setMethod("expm1", #' #' @rdname column_math_functions #' @aliases factorial factorial,Column-method -#' @export #' @note factorial since 1.5.0 setMethod("factorial", signature(x = "Column"), @@ -834,7 +813,6 @@ setMethod("factorial", #' @name first #' @aliases first,characterOrColumn-method #' @family aggregate functions -#' @export #' @examples #' \dontrun{ #' first(df$c) @@ -858,7 +836,6 @@ setMethod("first", #' #' @rdname column_math_functions #' @aliases floor floor,Column-method -#' @export #' @note floor since 1.5.0 setMethod("floor", signature(x = "Column"), @@ -872,7 +849,6 @@ setMethod("floor", #' #' @rdname column_math_functions #' @aliases hex hex,Column-method -#' @export #' @note hex since 1.5.0 setMethod("hex", signature(x = "Column"), @@ -886,7 +862,6 @@ setMethod("hex", #' #' @rdname column_datetime_functions #' @aliases hour hour,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -909,7 +884,6 @@ setMethod("hour", #' #' @rdname column_string_functions #' @aliases initcap initcap,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -944,7 +918,6 @@ setMethod("isnan", #' #' @rdname column_nonaggregate_functions #' @aliases is.nan is.nan,Column-method -#' @export #' @note is.nan since 2.0.0 setMethod("is.nan", signature(x = "Column"), @@ -957,7 +930,6 @@ 
setMethod("is.nan", #' #' @rdname column_aggregate_functions #' @aliases kurtosis kurtosis,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -986,7 +958,6 @@ setMethod("kurtosis", #' @name last #' @aliases last,characterOrColumn-method #' @family aggregate functions -#' @export #' @examples #' \dontrun{ #' last(df$c) @@ -1012,7 +983,6 @@ setMethod("last", #' #' @rdname column_datetime_functions #' @aliases last_day last_day,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1026,11 +996,12 @@ setMethod("last_day", }) #' @details -#' \code{length}: Computes the length of a given string or binary column. +#' \code{length}: Computes the character length of a string data or number of bytes +#' of a binary data. The length of string data includes the trailing spaces. +#' The length of binary data includes binary zeros. #' #' @rdname column_string_functions #' @aliases length length,Column-method -#' @export #' @note length since 1.5.0 setMethod("length", signature(x = "Column"), @@ -1044,7 +1015,6 @@ setMethod("length", #' #' @rdname column_math_functions #' @aliases log log,Column-method -#' @export #' @note log since 1.5.0 setMethod("log", signature(x = "Column"), @@ -1058,7 +1028,6 @@ setMethod("log", #' #' @rdname column_math_functions #' @aliases log10 log10,Column-method -#' @export #' @note log10 since 1.5.0 setMethod("log10", signature(x = "Column"), @@ -1072,7 +1041,6 @@ setMethod("log10", #' #' @rdname column_math_functions #' @aliases log1p log1p,Column-method -#' @export #' @note log1p since 1.5.0 setMethod("log1p", signature(x = "Column"), @@ -1086,7 +1054,6 @@ setMethod("log1p", #' #' @rdname column_math_functions #' @aliases log2 log2,Column-method -#' @export #' @note log2 since 1.5.0 setMethod("log2", signature(x = "Column"), @@ -1100,7 +1067,6 @@ setMethod("log2", #' #' @rdname column_string_functions #' @aliases lower lower,Column-method -#' @export #' @note lower since 1.4.0 setMethod("lower", signature(x = "Column"), @@ -1115,7 +1081,6 @@ setMethod("lower", #' #' @rdname column_string_functions #' @aliases ltrim ltrim,Column,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -1139,7 +1104,6 @@ setMethod("ltrim", #' @param trimString a character string to trim with #' @rdname column_string_functions #' @aliases ltrim,Column,character-method -#' @export #' @note ltrim(Column, character) since 2.3.0 setMethod("ltrim", signature(x = "Column", trimString = "character"), @@ -1167,7 +1131,6 @@ setMethod("max", #' #' @rdname column_misc_functions #' @aliases md5 md5,Column-method -#' @export #' @note md5 since 1.5.0 setMethod("md5", signature(x = "Column"), @@ -1181,7 +1144,6 @@ setMethod("md5", #' #' @rdname column_aggregate_functions #' @aliases mean mean,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1207,7 +1169,6 @@ setMethod("mean", #' #' @rdname column_aggregate_functions #' @aliases min min,Column-method -#' @export #' @note min since 1.5.0 setMethod("min", signature(x = "Column"), @@ -1221,7 +1182,6 @@ setMethod("min", #' #' @rdname column_datetime_functions #' @aliases minute minute,Column-method -#' @export #' @note minute since 1.5.0 setMethod("minute", signature(x = "Column"), @@ -1244,7 +1204,6 @@ setMethod("minute", #' #' @rdname column_nonaggregate_functions #' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method -#' @export #' @examples #' #' \dontrun{head(select(df, monotonically_increasing_id()))} @@ -1260,7 +1219,6 @@ setMethod("monotonically_increasing_id", #' #' @rdname column_datetime_functions #' 
@aliases month month,Column-method -#' @export #' @note month since 1.5.0 setMethod("month", signature(x = "Column"), @@ -1274,7 +1232,6 @@ setMethod("month", #' #' @rdname column_nonaggregate_functions #' @aliases negate negate,Column-method -#' @export #' @note negate since 1.5.0 setMethod("negate", signature(x = "Column"), @@ -1288,7 +1245,6 @@ setMethod("negate", #' #' @rdname column_datetime_functions #' @aliases quarter quarter,Column-method -#' @export #' @note quarter since 1.5.0 setMethod("quarter", signature(x = "Column"), @@ -1302,7 +1258,6 @@ setMethod("quarter", #' #' @rdname column_string_functions #' @aliases reverse reverse,Column-method -#' @export #' @note reverse since 1.5.0 setMethod("reverse", signature(x = "Column"), @@ -1317,7 +1272,6 @@ setMethod("reverse", #' #' @rdname column_math_functions #' @aliases rint rint,Column-method -#' @export #' @note rint since 1.5.0 setMethod("rint", signature(x = "Column"), @@ -1332,7 +1286,6 @@ setMethod("rint", #' #' @rdname column_math_functions #' @aliases round round,Column-method -#' @export #' @note round since 1.5.0 setMethod("round", signature(x = "Column"), @@ -1352,7 +1305,6 @@ setMethod("round", #' to the left of the decimal point when \code{scale} < 0. #' @rdname column_math_functions #' @aliases bround bround,Column-method -#' @export #' @note bround since 2.0.0 setMethod("bround", signature(x = "Column"), @@ -1367,7 +1319,6 @@ setMethod("bround", #' #' @rdname column_string_functions #' @aliases rtrim rtrim,Column,missing-method -#' @export #' @note rtrim since 1.5.0 setMethod("rtrim", signature(x = "Column", trimString = "missing"), @@ -1378,7 +1329,6 @@ setMethod("rtrim", #' @rdname column_string_functions #' @aliases rtrim,Column,character-method -#' @export #' @note rtrim(Column, character) since 2.3.0 setMethod("rtrim", signature(x = "Column", trimString = "character"), @@ -1392,7 +1342,6 @@ setMethod("rtrim", #' #' @rdname column_aggregate_functions #' @aliases sd sd,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1410,7 +1359,6 @@ setMethod("sd", #' #' @rdname column_datetime_functions #' @aliases second second,Column-method -#' @export #' @note second since 1.5.0 setMethod("second", signature(x = "Column"), @@ -1425,7 +1373,6 @@ setMethod("second", #' #' @rdname column_misc_functions #' @aliases sha1 sha1,Column-method -#' @export #' @note sha1 since 1.5.0 setMethod("sha1", signature(x = "Column"), @@ -1439,7 +1386,6 @@ setMethod("sha1", #' #' @rdname column_math_functions #' @aliases signum signum,Column-method -#' @export #' @note signum since 1.5.0 setMethod("signum", signature(x = "Column"), @@ -1453,7 +1399,6 @@ setMethod("signum", #' #' @rdname column_math_functions #' @aliases sign sign,Column-method -#' @export #' @note sign since 1.5.0 setMethod("sign", signature(x = "Column"), function(x) { @@ -1461,11 +1406,11 @@ setMethod("sign", signature(x = "Column"), }) #' @details -#' \code{sin}: Computes the sine of the given value. Units in radians. +#' \code{sin}: Returns the sine of the given value, +#' as if computed by \code{java.lang.Math.sin()}. Units in radians. #' #' @rdname column_math_functions #' @aliases sin sin,Column-method -#' @export #' @note sin since 1.5.0 setMethod("sin", signature(x = "Column"), @@ -1475,11 +1420,11 @@ setMethod("sin", }) #' @details -#' \code{sinh}: Computes the hyperbolic sine of the given value. +#' \code{sinh}: Returns the hyperbolic sine of the given value, +#' as if computed by \code{java.lang.Math.sinh()}. 
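For illustration (not part of the patch itself): a minimal SparkR sketch of the length() semantics described in the details earlier in this hunk, where string length counts characters including trailing spaces. The session setup, data, and the name df are hypothetical.

library(SparkR)
sparkR.session()
# length() on a string column counts characters, trailing spaces included
df <- createDataFrame(data.frame(s = c("abc", "abc  "), stringsAsFactors = FALSE))
head(select(df, df$s, length(df$s)))   # expected lengths: 3 and 5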
#' #' @rdname column_math_functions #' @aliases sinh sinh,Column-method -#' @export #' @note sinh since 1.5.0 setMethod("sinh", signature(x = "Column"), @@ -1493,7 +1438,6 @@ setMethod("sinh", #' #' @rdname column_aggregate_functions #' @aliases skewness skewness,Column-method -#' @export #' @note skewness since 1.6.0 setMethod("skewness", signature(x = "Column"), @@ -1507,7 +1451,6 @@ setMethod("skewness", #' #' @rdname column_string_functions #' @aliases soundex soundex,Column-method -#' @export #' @note soundex since 1.5.0 setMethod("soundex", signature(x = "Column"), @@ -1524,7 +1467,6 @@ setMethod("soundex", #' #' @rdname column_nonaggregate_functions #' @aliases spark_partition_id spark_partition_id,missing-method -#' @export #' @examples #' #' \dontrun{head(select(df, spark_partition_id()))} @@ -1554,7 +1496,6 @@ setMethod("stddev", #' #' @rdname column_aggregate_functions #' @aliases stddev_pop stddev_pop,Column-method -#' @export #' @note stddev_pop since 1.6.0 setMethod("stddev_pop", signature(x = "Column"), @@ -1568,7 +1509,6 @@ setMethod("stddev_pop", #' #' @rdname column_aggregate_functions #' @aliases stddev_samp stddev_samp,Column-method -#' @export #' @note stddev_samp since 1.6.0 setMethod("stddev_samp", signature(x = "Column"), @@ -1582,7 +1522,6 @@ setMethod("stddev_samp", #' #' @rdname column_nonaggregate_functions #' @aliases struct struct,characterOrColumn-method -#' @export #' @examples #' #' \dontrun{ @@ -1608,7 +1547,6 @@ setMethod("struct", #' #' @rdname column_math_functions #' @aliases sqrt sqrt,Column-method -#' @export #' @note sqrt since 1.5.0 setMethod("sqrt", signature(x = "Column"), @@ -1622,7 +1560,6 @@ setMethod("sqrt", #' #' @rdname column_aggregate_functions #' @aliases sum sum,Column-method -#' @export #' @note sum since 1.5.0 setMethod("sum", signature(x = "Column"), @@ -1636,7 +1573,6 @@ setMethod("sum", #' #' @rdname column_aggregate_functions #' @aliases sumDistinct sumDistinct,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1651,11 +1587,12 @@ setMethod("sumDistinct", }) #' @details -#' \code{tan}: Computes the tangent of the given value. Units in radians. +#' \code{tan}: Returns the tangent of the given value, +#' as if computed by \code{java.lang.Math.tan()}. +#' Units in radians. #' #' @rdname column_math_functions #' @aliases tan tan,Column-method -#' @export #' @note tan since 1.5.0 setMethod("tan", signature(x = "Column"), @@ -1665,11 +1602,11 @@ setMethod("tan", }) #' @details -#' \code{tanh}: Computes the hyperbolic tangent of the given value. +#' \code{tanh}: Returns the hyperbolic tangent of the given value, +#' as if computed by \code{java.lang.Math.tanh()}. 
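For illustration (not part of the patch itself): a brief sketch of the reworded trigonometric details above, which clarify that inputs are treated as radians, matching java.lang.Math. The data and the name df are hypothetical.

library(SparkR)
sparkR.session()
# sin() and tanh() operate on radians, per the updated details above
df <- createDataFrame(data.frame(x = c(0, pi / 2, pi)))
head(select(df, df$x, sin(df$x), tanh(df$x)))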
#' #' @rdname column_math_functions #' @aliases tanh tanh,Column-method -#' @export #' @note tanh since 1.5.0 setMethod("tanh", signature(x = "Column"), @@ -1684,7 +1621,6 @@ setMethod("tanh", #' #' @rdname column_math_functions #' @aliases toDegrees toDegrees,Column-method -#' @export #' @note toDegrees since 1.4.0 setMethod("toDegrees", signature(x = "Column"), @@ -1699,7 +1635,6 @@ setMethod("toDegrees", #' #' @rdname column_math_functions #' @aliases toRadians toRadians,Column-method -#' @export #' @note toRadians since 1.4.0 setMethod("toRadians", signature(x = "Column"), @@ -1719,7 +1654,6 @@ setMethod("toRadians", #' #' @rdname column_datetime_functions #' @aliases to_date to_date,Column,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -1740,7 +1674,6 @@ setMethod("to_date", #' @rdname column_datetime_functions #' @aliases to_date,Column,character-method -#' @export #' @note to_date(Column, character) since 2.2.0 setMethod("to_date", signature(x = "Column", format = "character"), @@ -1756,7 +1689,6 @@ setMethod("to_date", #' #' @rdname column_collection_functions #' @aliases to_json to_json,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1794,7 +1726,6 @@ setMethod("to_json", signature(x = "Column"), #' #' @rdname column_datetime_functions #' @aliases to_timestamp to_timestamp,Column,missing-method -#' @export #' @note to_timestamp(Column) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "missing"), @@ -1805,7 +1736,6 @@ setMethod("to_timestamp", #' @rdname column_datetime_functions #' @aliases to_timestamp,Column,character-method -#' @export #' @note to_timestamp(Column, character) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "character"), @@ -1820,7 +1750,6 @@ setMethod("to_timestamp", #' #' @rdname column_string_functions #' @aliases trim trim,Column,missing-method -#' @export #' @note trim since 1.5.0 setMethod("trim", signature(x = "Column", trimString = "missing"), @@ -1831,7 +1760,6 @@ setMethod("trim", #' @rdname column_string_functions #' @aliases trim,Column,character-method -#' @export #' @note trim(Column, character) since 2.3.0 setMethod("trim", signature(x = "Column", trimString = "character"), @@ -1846,7 +1774,6 @@ setMethod("trim", #' #' @rdname column_string_functions #' @aliases unbase64 unbase64,Column-method -#' @export #' @note unbase64 since 1.5.0 setMethod("unbase64", signature(x = "Column"), @@ -1861,7 +1788,6 @@ setMethod("unbase64", #' #' @rdname column_math_functions #' @aliases unhex unhex,Column-method -#' @export #' @note unhex since 1.5.0 setMethod("unhex", signature(x = "Column"), @@ -1875,7 +1801,6 @@ setMethod("unhex", #' #' @rdname column_string_functions #' @aliases upper upper,Column-method -#' @export #' @note upper since 1.4.0 setMethod("upper", signature(x = "Column"), @@ -1889,7 +1814,6 @@ setMethod("upper", #' #' @rdname column_aggregate_functions #' @aliases var var,Column-method -#' @export #' @examples #' #'\dontrun{ @@ -1904,7 +1828,6 @@ setMethod("var", #' @rdname column_aggregate_functions #' @aliases variance variance,Column-method -#' @export #' @note variance since 1.6.0 setMethod("variance", signature(x = "Column"), @@ -1918,7 +1841,6 @@ setMethod("variance", #' #' @rdname column_aggregate_functions #' @aliases var_pop var_pop,Column-method -#' @export #' @note var_pop since 1.5.0 setMethod("var_pop", signature(x = "Column"), @@ -1932,7 +1854,6 @@ setMethod("var_pop", #' #' @rdname column_aggregate_functions #' @aliases var_samp var_samp,Column-method -#' @export 
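For illustration (not part of the patch itself): a minimal sketch of to_date()/to_timestamp() with an explicit format, and of the trim(Column, character) variant shown in this hunk. The data and names are hypothetical.

library(SparkR)
sparkR.session()
df <- createDataFrame(data.frame(d = c("2018-04-01", "2018-12-24"), stringsAsFactors = FALSE))
# parse a string column into date and timestamp values
head(select(df, to_date(df$d, "yyyy-MM-dd"), to_timestamp(df$d, "yyyy-MM-dd")))
# trim a specific character from both ends (trim(Column, character), since 2.3.0)
df2 <- createDataFrame(data.frame(s = "xxhixx", stringsAsFactors = FALSE))
head(select(df2, trim(df2$s, "x")))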
#' @note var_samp since 1.6.0 setMethod("var_samp", signature(x = "Column"), @@ -1946,7 +1867,6 @@ setMethod("var_samp", #' #' @rdname column_datetime_functions #' @aliases weekofyear weekofyear,Column-method -#' @export #' @note weekofyear since 1.5.0 setMethod("weekofyear", signature(x = "Column"), @@ -1960,7 +1880,6 @@ setMethod("weekofyear", #' #' @rdname column_datetime_functions #' @aliases year year,Column-method -#' @export #' @note year since 1.5.0 setMethod("year", signature(x = "Column"), @@ -1971,11 +1890,11 @@ setMethod("year", #' @details #' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates -#' (x, y) to polar coordinates (r, theta). Units in radians. +#' (x, y) to polar coordinates (r, theta), +#' as if computed by \code{java.lang.Math.atan2()}. Units in radians. #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method -#' @export #' @note atan2 since 1.5.0 setMethod("atan2", signature(y = "Column"), function(y, x) { @@ -1991,7 +1910,6 @@ setMethod("atan2", signature(y = "Column"), #' #' @rdname column_datetime_diff_functions #' @aliases datediff datediff,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2015,7 +1933,6 @@ setMethod("datediff", signature(y = "Column"), #' #' @rdname column_math_functions #' @aliases hypot hypot,Column-method -#' @export #' @note hypot since 1.4.0 setMethod("hypot", signature(y = "Column"), function(y, x) { @@ -2031,7 +1948,6 @@ setMethod("hypot", signature(y = "Column"), #' #' @rdname column_string_functions #' @aliases levenshtein levenshtein,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2054,7 +1970,6 @@ setMethod("levenshtein", signature(y = "Column"), #' #' @rdname column_datetime_diff_functions #' @aliases months_between months_between,Column-method -#' @export #' @note months_between since 1.5.0 setMethod("months_between", signature(y = "Column"), function(y, x) { @@ -2072,7 +1987,6 @@ setMethod("months_between", signature(y = "Column"), #' #' @rdname column_nonaggregate_functions #' @aliases nanvl nanvl,Column-method -#' @export #' @note nanvl since 1.5.0 setMethod("nanvl", signature(y = "Column"), function(y, x) { @@ -2089,7 +2003,6 @@ setMethod("nanvl", signature(y = "Column"), #' #' @rdname column_math_functions #' @aliases pmod pmod,Column-method -#' @export #' @note pmod since 1.5.0 setMethod("pmod", signature(y = "Column"), function(y, x) { @@ -2104,7 +2017,6 @@ setMethod("pmod", signature(y = "Column"), #' #' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method -#' @export #' @note approxCountDistinct(Column, numeric) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), @@ -2118,7 +2030,6 @@ setMethod("approxCountDistinct", #' #' @rdname column_aggregate_functions #' @aliases countDistinct countDistinct,Column-method -#' @export #' @note countDistinct since 1.4.0 setMethod("countDistinct", signature(x = "Column"), @@ -2138,7 +2049,6 @@ setMethod("countDistinct", #' #' @rdname column_string_functions #' @aliases concat concat,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2167,7 +2077,6 @@ setMethod("concat", #' #' @rdname column_nonaggregate_functions #' @aliases greatest greatest,Column-method -#' @export #' @note greatest since 1.5.0 setMethod("greatest", signature(x = "Column"), @@ -2187,7 +2096,6 @@ setMethod("greatest", #' #' @rdname column_nonaggregate_functions #' @aliases least least,Column-method -#' @export #' @note least since 1.5.0 setMethod("least", signature(x = "Column"), @@ -2206,7 
+2114,6 @@ setMethod("least", #' #' @rdname column_aggregate_functions #' @aliases n_distinct n_distinct,Column-method -#' @export #' @note n_distinct since 1.4.0 setMethod("n_distinct", signature(x = "Column"), function(x, ...) { @@ -2216,7 +2123,6 @@ setMethod("n_distinct", signature(x = "Column"), #' @rdname count #' @name n #' @aliases n,Column-method -#' @export #' @examples \dontrun{n(df$c)} #' @note n since 1.4.0 setMethod("n", signature(x = "Column"), @@ -2235,7 +2141,6 @@ setMethod("n", signature(x = "Column"), #' @rdname column_datetime_diff_functions #' #' @aliases date_format date_format,Column,character-method -#' @export #' @note date_format since 1.5.0 setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { @@ -2253,7 +2158,6 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @param as.json.array indicating if input string is JSON array of objects or a single object. #' @aliases from_json from_json,Column,characterOrstructType-method -#' @export #' @examples #' #' \dontrun{ @@ -2296,7 +2200,6 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType") #' @rdname column_datetime_diff_functions #' #' @aliases from_utc_timestamp from_utc_timestamp,Column,character-method -#' @export #' @examples #' #' \dontrun{ @@ -2318,7 +2221,6 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' #' @rdname column_string_functions #' @aliases instr instr,Column,character-method -#' @export #' @examples #' #' \dontrun{ @@ -2341,7 +2243,6 @@ setMethod("instr", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases next_day next_day,Column,character-method -#' @export #' @note next_day since 1.5.0 setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { @@ -2356,7 +2257,6 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases to_utc_timestamp to_utc_timestamp,Column,character-method -#' @export #' @note to_utc_timestamp since 1.5.0 setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { @@ -2369,7 +2269,6 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases add_months add_months,Column,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2390,7 +2289,6 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' #' @rdname column_datetime_diff_functions #' @aliases date_add date_add,Column,numeric-method -#' @export #' @note date_add since 1.5.0 setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2404,7 +2302,6 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' @rdname column_datetime_diff_functions #' #' @aliases date_sub date_sub,Column,numeric-method -#' @export #' @note date_sub since 1.5.0 setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2421,7 +2318,6 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' #' @rdname column_string_functions #' @aliases format_number format_number,Column,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2444,7 +2340,6 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' #' @rdname column_misc_functions #' @aliases sha2 sha2,Column,numeric-method -#' @export #' @note sha2 since 1.5.0 
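For illustration (not part of the patch itself): a sketch of from_json() with the schema given as a string, following the note above that a DDL-formatted string is supported for the schema since Spark 2.3. The exact DDL text used here is an assumption for illustration; a structType() schema can be used instead.

library(SparkR)
sparkR.session()
df <- createDataFrame(data.frame(j = '{"name":"Bob","age":42}', stringsAsFactors = FALSE))
# parse the JSON string column into a struct using a DDL-style schema (assumed syntax)
head(select(df, from_json(df$j, "name STRING, age INT")))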
setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2458,7 +2353,6 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftLeft shiftLeft,Column,numeric-method -#' @export #' @note shiftLeft since 1.5.0 setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2474,7 +2368,6 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftRight shiftRight,Column,numeric-method -#' @export #' @note shiftRight since 1.5.0 setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2490,7 +2383,6 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method -#' @export #' @note shiftRightUnsigned since 1.5.0 setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2507,7 +2399,6 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @rdname column_string_functions #' @aliases concat_ws concat_ws,character,Column-method -#' @export #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -2523,7 +2414,6 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' @param toBase base to convert to. #' @rdname column_math_functions #' @aliases conv conv,Column,numeric,numeric-method -#' @export #' @note conv since 1.5.0 setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { @@ -2541,7 +2431,6 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' #' @rdname column_nonaggregate_functions #' @aliases expr expr,character-method -#' @export #' @note expr since 1.5.0 setMethod("expr", signature(x = "character"), function(x) { @@ -2556,7 +2445,6 @@ setMethod("expr", signature(x = "character"), #' @param format a character object of format strings. #' @rdname column_string_functions #' @aliases format_string format_string,character,Column-method -#' @export #' @note format_string since 1.5.0 setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { @@ -2577,7 +2465,6 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' @rdname column_datetime_functions #' #' @aliases from_unixtime from_unixtime,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2619,7 +2506,6 @@ setMethod("from_unixtime", signature(x = "Column"), #' \code{startTime} as \code{"15 minutes"}. #' @rdname column_datetime_functions #' @aliases window window,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2670,7 +2556,6 @@ setMethod("window", signature(x = "Column"), #' @param pos start position of search. #' @rdname column_string_functions #' @aliases locate locate,character,Column-method -#' @export #' @note locate since 1.5.0 setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 1) { @@ -2687,7 +2572,6 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' @param pad a character string to be padded with. 
#' @rdname column_string_functions #' @aliases lpad lpad,Column,numeric,character-method -#' @export #' @note lpad since 1.5.0 setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2704,7 +2588,6 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. #' @aliases rand rand,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -2719,7 +2602,6 @@ setMethod("rand", signature(seed = "missing"), #' @rdname column_nonaggregate_functions #' @aliases rand,numeric-method -#' @export #' @note rand(numeric) since 1.5.0 setMethod("rand", signature(seed = "numeric"), function(seed) { @@ -2733,7 +2615,6 @@ setMethod("rand", signature(seed = "numeric"), #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method -#' @export #' @note randn since 1.5.0 setMethod("randn", signature(seed = "missing"), function(seed) { @@ -2743,7 +2624,6 @@ setMethod("randn", signature(seed = "missing"), #' @rdname column_nonaggregate_functions #' @aliases randn,numeric-method -#' @export #' @note randn(numeric) since 1.5.0 setMethod("randn", signature(seed = "numeric"), function(seed) { @@ -2760,7 +2640,6 @@ setMethod("randn", signature(seed = "numeric"), #' @param idx a group index. #' @rdname column_string_functions #' @aliases regexp_extract regexp_extract,Column,character,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2789,7 +2668,6 @@ setMethod("regexp_extract", #' @param replacement a character string that a matched \code{pattern} is replaced with. #' @rdname column_string_functions #' @aliases regexp_replace regexp_replace,Column,character,character-method -#' @export #' @note regexp_replace since 1.5.0 setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), @@ -2805,7 +2683,6 @@ setMethod("regexp_replace", #' #' @rdname column_string_functions #' @aliases rpad rpad,Column,numeric,character-method -#' @export #' @note rpad since 1.5.0 setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2828,7 +2705,6 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' counting from the right. #' @rdname column_string_functions #' @aliases substring_index substring_index,Column,character,numeric-method -#' @export #' @note substring_index since 1.5.0 setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), @@ -2851,7 +2727,6 @@ setMethod("substring_index", #' at the same location, if any. 
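For illustration (not part of the patch itself): a small sketch of regexp_extract() and regexp_replace() as documented in this hunk; the pattern and data are hypothetical.

library(SparkR)
sparkR.session()
df <- createDataFrame(data.frame(s = c("100-200", "300-400"), stringsAsFactors = FALSE))
# extract the first capture group, and replace the separator
head(select(df,
            regexp_extract(df$s, "(\\d+)-(\\d+)", 1),
            regexp_replace(df$s, "-", ":")))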
#' @rdname column_string_functions #' @aliases translate translate,Column,character,character-method -#' @export #' @note translate since 1.5.0 setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), @@ -2866,7 +2741,6 @@ setMethod("translate", #' #' @rdname column_datetime_functions #' @aliases unix_timestamp unix_timestamp,missing,missing-method -#' @export #' @note unix_timestamp since 1.5.0 setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { @@ -2876,7 +2750,6 @@ setMethod("unix_timestamp", signature(x = "missing", format = "missing"), #' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,missing-method -#' @export #' @note unix_timestamp(Column) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "missing"), function(x, format) { @@ -2886,7 +2759,6 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), #' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,character-method -#' @export #' @note unix_timestamp(Column, character) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "character"), function(x, format = "yyyy-MM-dd HH:mm:ss") { @@ -2902,7 +2774,6 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' @param condition the condition to test on. Must be a Column expression. #' @param value result expression. #' @aliases when when,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2931,7 +2802,6 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @param yes return values for \code{TRUE} elements of test. #' @param no return values for \code{FALSE} elements of test. #' @aliases ifelse ifelse,Column-method -#' @export #' @note ifelse since 1.5.0 setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), @@ -2957,7 +2827,6 @@ setMethod("ifelse", #' #' @rdname column_window_functions #' @aliases cume_dist cume_dist,missing-method -#' @export #' @note cume_dist since 1.6.0 setMethod("cume_dist", signature("missing"), @@ -2978,7 +2847,6 @@ setMethod("cume_dist", #' #' @rdname column_window_functions #' @aliases dense_rank dense_rank,missing-method -#' @export #' @note dense_rank since 1.6.0 setMethod("dense_rank", signature("missing"), @@ -2995,7 +2863,6 @@ setMethod("dense_rank", #' #' @rdname column_window_functions #' @aliases lag lag,characterOrColumn-method -#' @export #' @note lag since 1.6.0 setMethod("lag", signature(x = "characterOrColumn"), @@ -3020,7 +2887,6 @@ setMethod("lag", #' #' @rdname column_window_functions #' @aliases lead lead,characterOrColumn,numeric-method -#' @export #' @note lead since 1.6.0 setMethod("lead", signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), @@ -3044,7 +2910,6 @@ setMethod("lead", #' #' @rdname column_window_functions #' @aliases ntile ntile,numeric-method -#' @export #' @note ntile since 1.6.0 setMethod("ntile", signature(x = "numeric"), @@ -3062,7 +2927,6 @@ setMethod("ntile", #' #' @rdname column_window_functions #' @aliases percent_rank percent_rank,missing-method -#' @export #' @note percent_rank since 1.6.0 setMethod("percent_rank", signature("missing"), @@ -3083,7 +2947,6 @@ setMethod("percent_rank", #' #' @rdname column_window_functions #' @aliases rank rank,missing-method -#' @export #' @note rank since 1.6.0 setMethod("rank", signature(x = "missing"), @@ -3094,7 +2957,6 @@ setMethod("rank", #' @rdname column_window_functions #' @aliases rank,ANY-method 
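For illustration (not part of the patch itself): a minimal sketch combining the conditional helpers when()/otherwise() and ifelse() with a window function (lag() over a WindowSpec), all documented in this hunk. The grouping column and values are hypothetical.

library(SparkR)
sparkR.session()
df <- createDataFrame(data.frame(grp = c("a", "a", "b"), x = c(1, 2, 5), stringsAsFactors = FALSE))
ws <- orderBy(windowPartitionBy("grp"), "x")
head(select(df,
            df$grp, df$x,
            over(lag(df$x, 1), ws),                     # previous x within each group
            otherwise(when(df$x > 1, "big"), "small"),  # CASE WHEN-style expression
            ifelse(df$x > 1, "big", "small")))          # vectorized if/else on a Column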
-#' @export setMethod("rank", signature(x = "ANY"), function(x, ...) { @@ -3108,7 +2970,6 @@ setMethod("rank", #' #' @rdname column_window_functions #' @aliases row_number row_number,missing-method -#' @export #' @note row_number since 1.6.0 setMethod("row_number", signature("missing"), @@ -3123,10 +2984,8 @@ setMethod("row_number", #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. #' -#' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method -#' @export #' @note array_contains since 1.6.0 setMethod("array_contains", signature(x = "Column", value = "ANY"), @@ -3135,12 +2994,66 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_max}: Returns the maximum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_max array_max,Column-method +#' @note array_max since 2.4.0 +setMethod("array_max", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_max", x@jc) + column(jc) + }) + +#' @details +#' \code{array_min}: Returns the minimum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_min array_min,Column-method +#' @note array_min since 2.4.0 +setMethod("array_min", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_min", x@jc) + column(jc) + }) + +#' @details +#' \code{array_position}: Locates the position of the first occurrence of the given value +#' in the given array. Returns NA if either of the arguments are NA. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the given +#' value could not be found in the array. +#' +#' @rdname column_collection_functions +#' @aliases array_position array_position,Column-method +#' @note array_position since 2.4.0 +setMethod("array_position", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_position", x@jc, value) + column(jc) + }) + +#' @details +#' \code{flatten}: Transforms an array of arrays into a single array. +#' +#' @rdname column_collection_functions +#' @aliases flatten flatten,Column-method +#' @note flatten since 2.4.0 +setMethod("flatten", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "flatten", x@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' #' @rdname column_collection_functions #' @aliases map_keys map_keys,Column-method -#' @export #' @note map_keys since 2.3.0 setMethod("map_keys", signature(x = "Column"), @@ -3154,7 +3067,6 @@ setMethod("map_keys", #' #' @rdname column_collection_functions #' @aliases map_values map_values,Column-method -#' @export #' @note map_values since 2.3.0 setMethod("map_values", signature(x = "Column"), @@ -3163,12 +3075,27 @@ setMethod("map_values", column(jc) }) +#' @details +#' \code{element_at}: Returns element of array at given index in \code{extraction} if +#' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. +#' Note: The position is not zero based, but 1 based index. 
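For illustration (not part of the patch itself): a minimal sketch of the newly added collection functions array_max(), array_min(), array_position(), and flatten() (all since 2.4.0 per the notes above). The data and names are hypothetical.

library(SparkR)
sparkR.session()
df <- createDataFrame(data.frame(x = c(2, 1), y = c(3, 5)))
tmp <- mutate(df, v = create_array(df$x, df$y))    # v is array<double>
head(select(tmp,
            array_max(tmp$v),                      # per-row maximum of the array
            array_min(tmp$v),                      # per-row minimum of the array
            array_position(tmp$v, 5),              # 1-based position of 5, or 0 if absent
            flatten(create_array(tmp$v, tmp$v))))  # array<array<double>> -> array<double>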
+#' +#' @param extraction index to check for in array or key to check for in map +#' @rdname column_collection_functions +#' @aliases element_at element_at,Column-method +#' @note element_at since 2.4.0 +setMethod("element_at", + signature(x = "Column", extraction = "ANY"), + function(x, extraction) { + jc <- callJStatic("org.apache.spark.sql.functions", "element_at", x@jc, extraction) + column(jc) + }) + #' @details #' \code{explode}: Creates a new row for each element in the given array or map column. #' #' @rdname column_collection_functions #' @aliases explode explode,Column-method -#' @export #' @note explode since 1.5.0 setMethod("explode", signature(x = "Column"), @@ -3182,7 +3109,6 @@ setMethod("explode", #' #' @rdname column_collection_functions #' @aliases size size,Column-method -#' @export #' @note size since 1.5.0 setMethod("size", signature(x = "Column"), @@ -3200,7 +3126,6 @@ setMethod("size", #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. #' @aliases sort_array sort_array,Column-method -#' @export #' @note sort_array since 1.6.0 setMethod("sort_array", signature(x = "Column"), @@ -3215,7 +3140,6 @@ setMethod("sort_array", #' #' @rdname column_collection_functions #' @aliases posexplode posexplode,Column-method -#' @export #' @note posexplode since 2.1.0 setMethod("posexplode", signature(x = "Column"), @@ -3230,7 +3154,6 @@ setMethod("posexplode", #' #' @rdname column_nonaggregate_functions #' @aliases create_array create_array,Column-method -#' @export #' @note create_array since 2.3.0 setMethod("create_array", signature(x = "Column"), @@ -3251,7 +3174,6 @@ setMethod("create_array", #' #' @rdname column_nonaggregate_functions #' @aliases create_map create_map,Column-method -#' @export #' @note create_map since 2.3.0 setMethod("create_map", signature(x = "Column"), @@ -3269,7 +3191,6 @@ setMethod("create_map", #' #' @rdname column_aggregate_functions #' @aliases collect_list collect_list,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3289,7 +3210,6 @@ setMethod("collect_list", #' #' @rdname column_aggregate_functions #' @aliases collect_set collect_set,Column-method -#' @export #' @note collect_set since 2.3.0 setMethod("collect_set", signature(x = "Column"), @@ -3304,7 +3224,6 @@ setMethod("collect_set", #' #' @rdname column_string_functions #' @aliases split_string split_string,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3327,7 +3246,6 @@ setMethod("split_string", #' @param n number of repetitions. 
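For illustration (not part of the patch itself): a short sketch of element_at() (since 2.4.0) on both an array column (1-based index) and a map column (lookup by key), as described above. The data and names are hypothetical.

library(SparkR)
sparkR.session()
df <- createDataFrame(data.frame(x = c(1, 2), y = c(3, 4)))
tmp <- mutate(df,
              v = create_array(df$x, df$y),
              m = create_map(lit("x"), df$x, lit("y"), df$y))
head(select(tmp,
            element_at(tmp$v, 1L),     # first array element (1-based index)
            element_at(tmp$m, "x")))   # value for key "x" in the map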
#' @rdname column_string_functions #' @aliases repeat_string repeat_string,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3350,7 +3268,6 @@ setMethod("repeat_string", #' #' @rdname column_collection_functions #' @aliases explode_outer explode_outer,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3375,7 +3292,6 @@ setMethod("explode_outer", #' #' @rdname column_collection_functions #' @aliases posexplode_outer posexplode_outer,Column-method -#' @export #' @note posexplode_outer since 2.3.0 setMethod("posexplode_outer", signature(x = "Column"), @@ -3396,7 +3312,6 @@ setMethod("posexplode_outer", #' @name not #' @aliases not,Column-method #' @family non-aggregate functions -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame( @@ -3424,7 +3339,6 @@ setMethod("not", #' #' @rdname column_aggregate_functions #' @aliases grouping_bit grouping_bit,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3457,7 +3371,6 @@ setMethod("grouping_bit", #' #' @rdname column_aggregate_functions #' @aliases grouping_id grouping_id,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3492,7 +3405,6 @@ setMethod("grouping_id", #' #' @rdname column_nonaggregate_functions #' @aliases input_file_name input_file_name,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -3510,7 +3422,6 @@ setMethod("input_file_name", signature("missing"), #' #' @rdname column_datetime_functions #' @aliases trunc trunc,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3530,7 +3441,6 @@ setMethod("trunc", #' #' @rdname column_datetime_functions #' @aliases date_trunc date_trunc,character,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3549,7 +3459,6 @@ setMethod("date_trunc", #' #' @rdname column_datetime_functions #' @aliases current_date current_date,missing-method -#' @export #' @examples #' \dontrun{ #' head(select(df, current_date(), current_timestamp()))} @@ -3566,7 +3475,6 @@ setMethod("current_date", #' #' @rdname column_datetime_functions #' @aliases current_timestamp current_timestamp,missing-method -#' @export #' @note current_timestamp since 2.3.0 setMethod("current_timestamp", signature("missing"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5369c32544e5..4ef12d19b357 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -19,7 +19,6 @@ # @rdname aggregateRDD # @seealso reduce -# @export setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) @@ -27,21 +26,17 @@ setGeneric("cacheRDD", function(x) { standardGeneric("cacheRDD") }) # @rdname coalesce # @seealso repartition -# @export setGeneric("coalesceRDD", function(x, numPartitions, ...) { standardGeneric("coalesceRDD") }) # @rdname checkpoint-methods -# @export setGeneric("checkpointRDD", function(x) { standardGeneric("checkpointRDD") }) setGeneric("collectRDD", function(x, ...) 
{ standardGeneric("collectRDD") }) # @rdname collect-methods -# @export setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) # @rdname collect-methods -# @export setGeneric("collectPartition", function(x, partitionId) { standardGeneric("collectPartition") @@ -52,19 +47,15 @@ setGeneric("countRDD", function(x) { standardGeneric("countRDD") }) setGeneric("lengthRDD", function(x) { standardGeneric("lengthRDD") }) # @rdname countByValue -# @export setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) # @rdname crosstab -# @export setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) # @rdname freqItems -# @export setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) # @rdname approxQuantile -# @export setGeneric("approxQuantile", function(x, cols, probabilities, relativeError) { standardGeneric("approxQuantile") @@ -73,18 +64,15 @@ setGeneric("approxQuantile", setGeneric("distinctRDD", function(x, numPartitions = 1) { standardGeneric("distinctRDD") }) # @rdname filterRDD -# @export setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) setGeneric("firstRDD", function(x, ...) { standardGeneric("firstRDD") }) # @rdname flatMap -# @export setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) # @rdname fold # @seealso reduce -# @export setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) @@ -95,17 +83,14 @@ setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachParti setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) # @rdname glom -# @export setGeneric("glom", function(x) { standardGeneric("glom") }) # @rdname histogram -# @export setGeneric("histogram", function(df, col, nbins=10) { standardGeneric("histogram") }) setGeneric("joinRDD", function(x, y, ...) { standardGeneric("joinRDD") }) # @rdname keyBy -# @export setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) @@ -123,47 +108,37 @@ setGeneric("mapPartitionsWithIndex", function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) # @rdname maximum -# @export setGeneric("maximum", function(x) { standardGeneric("maximum") }) # @rdname minimum -# @export setGeneric("minimum", function(x) { standardGeneric("minimum") }) # @rdname sumRDD -# @export setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) # @rdname name -# @export setGeneric("name", function(x) { standardGeneric("name") }) # @rdname getNumPartitionsRDD -# @export setGeneric("getNumPartitionsRDD", function(x) { standardGeneric("getNumPartitionsRDD") }) # @rdname getNumPartitions -# @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) setGeneric("persistRDD", function(x, newLevel) { standardGeneric("persistRDD") }) # @rdname pipeRDD -# @export setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) # @rdname pivot -# @export setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("pivot") }) # @rdname reduce -# @export setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) setGeneric("repartitionRDD", function(x, ...) 
{ standardGeneric("repartitionRDD") }) # @rdname sampleRDD -# @export setGeneric("sampleRDD", function(x, withReplacement, fraction, seed) { standardGeneric("sampleRDD") @@ -171,21 +146,17 @@ setGeneric("sampleRDD", # @rdname saveAsObjectFile # @seealso objectFile -# @export setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") }) # @rdname saveAsTextFile -# @export setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) # @rdname setName -# @export setGeneric("setName", function(x, name) { standardGeneric("setName") }) setGeneric("showRDD", function(object, ...) { standardGeneric("showRDD") }) # @rdname sortBy -# @export setGeneric("sortBy", function(x, func, ascending = TRUE, numPartitions = 1) { standardGeneric("sortBy") @@ -194,88 +165,71 @@ setGeneric("sortBy", setGeneric("takeRDD", function(x, num) { standardGeneric("takeRDD") }) # @rdname takeOrdered -# @export setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) # @rdname takeSample -# @export setGeneric("takeSample", function(x, withReplacement, num, seed) { standardGeneric("takeSample") }) # @rdname top -# @export setGeneric("top", function(x, num) { standardGeneric("top") }) # @rdname unionRDD -# @export setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) setGeneric("unpersistRDD", function(x, ...) { standardGeneric("unpersistRDD") }) # @rdname zipRDD -# @export setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) # @rdname zipRDD -# @export setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, signature = "...") # @rdname zipWithIndex # @seealso zipWithUniqueId -# @export setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) # @rdname zipWithUniqueId # @seealso zipWithIndex -# @export setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) ############ Binary Functions ############# # @rdname cartesian -# @export setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") }) # @rdname countByKey -# @export setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) # @rdname flatMapValues -# @export setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) # @rdname intersection -# @export setGeneric("intersection", function(x, other, numPartitions = 1) { standardGeneric("intersection") }) # @rdname keys -# @export setGeneric("keys", function(x) { standardGeneric("keys") }) # @rdname lookup -# @export setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) # @rdname mapValues -# @export setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) # @rdname sampleByKey -# @export setGeneric("sampleByKey", function(x, withReplacement, fractions, seed) { standardGeneric("sampleByKey") }) # @rdname values -# @export setGeneric("values", function(x) { standardGeneric("values") }) @@ -283,14 +237,12 @@ setGeneric("values", function(x) { standardGeneric("values") }) # @rdname aggregateByKey # @seealso foldByKey, combineByKey -# @export setGeneric("aggregateByKey", function(x, zeroValue, seqOp, combOp, numPartitions) { standardGeneric("aggregateByKey") }) # @rdname cogroup -# @export setGeneric("cogroup", function(..., numPartitions) { standardGeneric("cogroup") @@ -299,7 +251,6 @@ setGeneric("cogroup", # @rdname combineByKey # @seealso groupByKey, reduceByKey -# @export setGeneric("combineByKey", function(x, createCombiner, mergeValue, mergeCombiners, 
numPartitions) { standardGeneric("combineByKey") @@ -307,64 +258,53 @@ setGeneric("combineByKey", # @rdname foldByKey # @seealso aggregateByKey, combineByKey -# @export setGeneric("foldByKey", function(x, zeroValue, func, numPartitions) { standardGeneric("foldByKey") }) # @rdname join-methods -# @export setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) # @rdname groupByKey # @seealso reduceByKey -# @export setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) # @rdname join-methods -# @export setGeneric("join", function(x, y, ...) { standardGeneric("join") }) # @rdname join-methods -# @export setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) setGeneric("partitionByRDD", function(x, ...) { standardGeneric("partitionByRDD") }) # @rdname reduceByKey # @seealso groupByKey -# @export setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) # @rdname reduceByKeyLocally # @seealso reduceByKey -# @export setGeneric("reduceByKeyLocally", function(x, combineFunc) { standardGeneric("reduceByKeyLocally") }) # @rdname join-methods -# @export setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) # @rdname sortByKey -# @export setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1) { standardGeneric("sortByKey") }) # @rdname subtract -# @export setGeneric("subtract", function(x, other, numPartitions = 1) { standardGeneric("subtract") }) # @rdname subtractByKey -# @export setGeneric("subtractByKey", function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") @@ -374,7 +314,6 @@ setGeneric("subtractByKey", ################### Broadcast Variable Methods ################# # @rdname broadcast -# @export setGeneric("value", function(bcast) { standardGeneric("value") }) @@ -384,7 +323,6 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @param ... further arguments to be passed to or from other methods. #' @return A SparkDataFrame. #' @rdname summarize -#' @export setGeneric("agg", function(x, ...) { standardGeneric("agg") }) #' alias @@ -399,11 +337,9 @@ setGeneric("agg", function(x, ...) { standardGeneric("agg") }) NULL #' @rdname arrange -#' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) #' @rdname as.data.frame -#' @export setGeneric("as.data.frame", function(x, row.names = NULL, optional = FALSE, ...) { standardGeneric("as.data.frame") @@ -411,52 +347,41 @@ setGeneric("as.data.frame", # Do not document the generic because of signature changes across R versions #' @noRd -#' @export setGeneric("attach") #' @rdname cache -#' @export setGeneric("cache", function(x) { standardGeneric("cache") }) #' @rdname checkpoint -#' @export setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) #' @rdname coalesce #' @param x a SparkDataFrame. #' @param ... additional argument(s). -#' @export setGeneric("coalesce", function(x, ...) { standardGeneric("coalesce") }) #' @rdname collect -#' @export setGeneric("collect", function(x, ...) { standardGeneric("collect") }) #' @param do.NULL currently not used. #' @param prefix currently not used. 
#' @rdname columns -#' @export setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) #' @rdname columns -#' @export setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) #' @rdname coltypes -#' @export setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) #' @rdname coltypes -#' @export setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) #' @rdname columns -#' @export setGeneric("columns", function(x) {standardGeneric("columns") }) #' @param x a GroupedData or Column. #' @rdname count -#' @export setGeneric("count", function(x) { standardGeneric("count") }) #' @rdname cov @@ -464,7 +389,6 @@ setGeneric("count", function(x) { standardGeneric("count") }) #' @param ... additional argument(s). If \code{x} is a Column, a Column #' should be provided. If \code{x} is a SparkDataFrame, two column names should #' be provided. -#' @export setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @rdname corr @@ -472,294 +396,232 @@ setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @param ... additional argument(s). If \code{x} is a Column, a Column #' should be provided. If \code{x} is a SparkDataFrame, two column names should #' be provided. -#' @export setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @rdname cov -#' @export setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) #' @rdname cov -#' @export setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) #' @rdname createOrReplaceTempView -#' @export setGeneric("createOrReplaceTempView", function(x, viewName) { standardGeneric("createOrReplaceTempView") }) # @rdname crossJoin -# @export setGeneric("crossJoin", function(x, y) { standardGeneric("crossJoin") }) #' @rdname cube -#' @export setGeneric("cube", function(x, ...) { standardGeneric("cube") }) #' @rdname dapply -#' @export setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) #' @rdname dapplyCollect -#' @export setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") }) #' @param x a SparkDataFrame or GroupedData. #' @param ... additional argument(s) passed to the method. #' @rdname gapply -#' @export setGeneric("gapply", function(x, ...) { standardGeneric("gapply") }) #' @param x a SparkDataFrame or GroupedData. #' @param ... additional argument(s) passed to the method. #' @rdname gapplyCollect -#' @export setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") }) # @rdname getNumPartitions -# @export setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) #' @rdname describe -#' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) #' @rdname distinct -#' @export setGeneric("distinct", function(x) { standardGeneric("distinct") }) #' @rdname drop -#' @export setGeneric("drop", function(x, ...) { standardGeneric("drop") }) #' @rdname dropDuplicates -#' @export setGeneric("dropDuplicates", function(x, ...) { standardGeneric("dropDuplicates") }) #' @rdname nafunctions -#' @export setGeneric("dropna", function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { standardGeneric("dropna") }) #' @rdname nafunctions -#' @export setGeneric("na.omit", function(object, ...) 
{ standardGeneric("na.omit") }) #' @rdname dtypes -#' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @rdname explain -#' @export #' @param x a SparkDataFrame or a StreamingQuery. #' @param extended Logical. If extended is FALSE, prints only the physical plan. #' @param ... further arguments to be passed to or from other methods. setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @rdname except -#' @export setGeneric("except", function(x, y) { standardGeneric("except") }) #' @rdname nafunctions -#' @export setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) #' @rdname filter -#' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) #' @rdname first -#' @export setGeneric("first", function(x, ...) { standardGeneric("first") }) #' @rdname groupBy -#' @export setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @rdname groupBy -#' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) #' @rdname hint -#' @export setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) #' @rdname insertInto -#' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) #' @rdname intersect -#' @export setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) #' @rdname isLocal -#' @export setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @rdname isStreaming -#' @export setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) #' @rdname limit -#' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) #' @rdname localCheckpoint -#' @export setGeneric("localCheckpoint", function(x, eager = TRUE) { standardGeneric("localCheckpoint") }) #' @rdname merge -#' @export setGeneric("merge") #' @rdname mutate -#' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) #' @rdname orderBy -#' @export setGeneric("orderBy", function(x, col, ...) { standardGeneric("orderBy") }) #' @rdname persist -#' @export setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) #' @rdname printSchema -#' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) #' @rdname registerTempTable-deprecated -#' @export setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) #' @rdname rename -#' @export setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @rdname repartition -#' @export setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) +#' @rdname repartitionByRange +setGeneric("repartitionByRange", function(x, ...) { standardGeneric("repartitionByRange") }) + #' @rdname sample -#' @export setGeneric("sample", function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample") }) #' @rdname rollup -#' @export setGeneric("rollup", function(x, ...) { standardGeneric("rollup") }) #' @rdname sample -#' @export setGeneric("sample_frac", function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname sampleBy -#' @export setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) #' @rdname saveAsTable -#' @export setGeneric("saveAsTable", function(df, tableName, source = NULL, mode = "error", ...) 
{ standardGeneric("saveAsTable") }) -#' @export setGeneric("str") #' @rdname take -#' @export setGeneric("take", function(x, num) { standardGeneric("take") }) #' @rdname mutate -#' @export setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) #' @rdname write.df -#' @export setGeneric("write.df", function(df, path = NULL, source = NULL, mode = "error", ...) { standardGeneric("write.df") }) #' @rdname write.df -#' @export setGeneric("saveDF", function(df, path, source = NULL, mode = "error", ...) { standardGeneric("saveDF") }) #' @rdname write.jdbc -#' @export setGeneric("write.jdbc", function(x, url, tableName, mode = "error", ...) { standardGeneric("write.jdbc") }) #' @rdname write.json -#' @export setGeneric("write.json", function(x, path, ...) { standardGeneric("write.json") }) #' @rdname write.orc -#' @export setGeneric("write.orc", function(x, path, ...) { standardGeneric("write.orc") }) #' @rdname write.parquet -#' @export setGeneric("write.parquet", function(x, path, ...) { standardGeneric("write.parquet") }) #' @rdname write.parquet -#' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) #' @rdname write.stream -#' @export setGeneric("write.stream", function(df, source = NULL, outputMode = NULL, ...) { standardGeneric("write.stream") }) #' @rdname write.text -#' @export setGeneric("write.text", function(x, path, ...) { standardGeneric("write.text") }) #' @rdname schema -#' @export setGeneric("schema", function(x) { standardGeneric("schema") }) #' @rdname select -#' @export setGeneric("select", function(x, col, ...) { standardGeneric("select") }) #' @rdname selectExpr -#' @export setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) #' @rdname showDF -#' @export setGeneric("showDF", function(x, ...) { standardGeneric("showDF") }) # @rdname storageLevel -# @export setGeneric("storageLevel", function(x) { standardGeneric("storageLevel") }) #' @rdname subset -#' @export setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname summarize -#' @export setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary -#' @export setGeneric("summary", function(object, ...) { standardGeneric("summary") }) setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) @@ -767,824 +629,680 @@ setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) #' @rdname union -#' @export setGeneric("union", function(x, y) { standardGeneric("union") }) #' @rdname union -#' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) #' @rdname unionByName -#' @export setGeneric("unionByName", function(x, y) { standardGeneric("unionByName") }) #' @rdname unpersist -#' @export setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) #' @rdname filter -#' @export setGeneric("where", function(x, condition) { standardGeneric("where") }) #' @rdname with -#' @export setGeneric("with") #' @rdname withColumn -#' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) #' @rdname rename -#' @export setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) +#' @rdname withWatermark +setGeneric("withWatermark", function(x, eventTime, delayThreshold) { + standardGeneric("withWatermark") +}) + #' @rdname write.df -#' @export setGeneric("write.df", function(df, path = NULL, ...) 
{ standardGeneric("write.df") }) #' @rdname randomSplit -#' @export setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) #' @rdname broadcast -#' @export setGeneric("broadcast", function(x) { standardGeneric("broadcast") }) ###################### Column Methods ########################## #' @rdname columnfunctions -#' @export setGeneric("asc", function(x) { standardGeneric("asc") }) #' @rdname between -#' @export setGeneric("between", function(x, bounds) { standardGeneric("between") }) #' @rdname cast -#' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) #' @rdname columnfunctions #' @param x a Column object. #' @param ... additional argument(s). -#' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) #' @rdname columnfunctions -#' @export setGeneric("desc", function(x) { standardGeneric("desc") }) #' @rdname endsWith -#' @export setGeneric("endsWith", function(x, suffix) { standardGeneric("endsWith") }) #' @rdname columnfunctions -#' @export setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @rdname columnfunctions -#' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) #' @rdname columnfunctions -#' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) #' @rdname columnfunctions -#' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) #' @rdname columnfunctions -#' @export setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) #' @rdname columnfunctions -#' @export setGeneric("like", function(x, ...) { standardGeneric("like") }) #' @rdname columnfunctions -#' @export setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @rdname startsWith -#' @export setGeneric("startsWith", function(x, prefix) { standardGeneric("startsWith") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname otherwise -#' @export setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) #' @rdname over -#' @export setGeneric("over", function(x, window) { standardGeneric("over") }) #' @rdname eq_null_safe -#' @export setGeneric("%<=>%", function(x, value) { standardGeneric("%<=>%") }) ###################### WindowSpec Methods ########################## #' @rdname partitionBy -#' @export setGeneric("partitionBy", function(x, ...) { standardGeneric("partitionBy") }) #' @rdname rowsBetween -#' @export setGeneric("rowsBetween", function(x, start, end) { standardGeneric("rowsBetween") }) #' @rdname rangeBetween -#' @export setGeneric("rangeBetween", function(x, start, end) { standardGeneric("rangeBetween") }) #' @rdname windowPartitionBy -#' @export setGeneric("windowPartitionBy", function(col, ...) { standardGeneric("windowPartitionBy") }) #' @rdname windowOrderBy -#' @export setGeneric("windowOrderBy", function(col, ...) { standardGeneric("windowOrderBy") }) ###################### Expression Function Methods ########################## #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("approxCountDistinct", function(x, ...) 
{ standardGeneric("approxCountDistinct") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_max", function(x) { standardGeneric("array_max") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_min", function(x) { standardGeneric("array_min") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) + #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @param x Column to compute on or a GroupedData object. #' @param ... additional argument(s) when \code{x} is a GroupedData object. #' @rdname avg -#' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("base64", function(x) { standardGeneric("base64") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("bround", function(x, ...) { standardGeneric("bround") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("ceil", function(x) { standardGeneric("ceil") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column -#' @export setGeneric("column", function(x) { standardGeneric("column") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("crc32", function(x) { standardGeneric("crc32") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("hash", function(x, ...) 
{ standardGeneric("hash") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("current_date", function(x = "missing") { standardGeneric("current_date") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("current_timestamp", function(x = "missing") { standardGeneric("current_timestamp") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("date_trunc", function(format, x) { standardGeneric("date_trunc") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofweek", function(x) { standardGeneric("dayofweek") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("element_at", function(x, extraction) { standardGeneric("element_at") }) + #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("explode", function(x) { standardGeneric("explode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("flatten", function(x) { standardGeneric("flatten") }) + #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("greatest", function(x, ...) 
{ standardGeneric("greatest") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("hex", function(x) { standardGeneric("hex") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("hour", function(x) { standardGeneric("hour") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("input_file_name", function(x = "missing") { standardGeneric("input_file_name") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last -#' @export setGeneric("last", function(x, ...) { standardGeneric("last") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("last_day", function(x) { standardGeneric("last_day") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("least", function(x, ...) { standardGeneric("least") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("lit", function(x) { standardGeneric("lit") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("locate", function(substr, str, ...) 
{ standardGeneric("locate") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("lower", function(x) { standardGeneric("lower") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("map_values", function(x) { standardGeneric("map_values") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("md5", function(x) { standardGeneric("md5") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("monotonically_increasing_id", function(x = "missing") { standardGeneric("monotonically_increasing_id") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("month", function(x) { standardGeneric("month") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) #' @rdname count -#' @export setGeneric("n", function(x) { standardGeneric("n") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("negate", function(x) { standardGeneric("negate") }) #' @rdname not -#' @export setGeneric("not", function(x) { standardGeneric("not") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("rand", function(seed) { standardGeneric("rand") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("rank", function(x, ...) 
{ standardGeneric("rank") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("rtrim", function(x, trimString) { standardGeneric("rtrim") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("second", function(x) { standardGeneric("second") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("sha1", function(x) { standardGeneric("sha1") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("size", function(x) { standardGeneric("size") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev", function(x) { standardGeneric("stddev") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("struct", function(x, ...) 
{ standardGeneric("struct") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("trim", function(x, trimString) { standardGeneric("trim") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("upper", function(x) { standardGeneric("upper") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("variance", function(x) { standardGeneric("variance") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("year", function(x) { standardGeneric("year") }) @@ -1592,142 +1310,110 @@ setGeneric("year", function(x) { standardGeneric("year") }) ###################### Spark.ML Methods ########################## #' @rdname fitted -#' @export setGeneric("fitted") # Do not carry stats::glm usage and param here, and do not document the generic -#' @export #' @noRd setGeneric("glm") #' @param object a fitted ML model object. #' @param ... additional argument(s) passed to the method. #' @rdname predict -#' @export setGeneric("predict", function(object, ...) 
{ standardGeneric("predict") }) #' @rdname rbind -#' @export setGeneric("rbind", signature = "...") #' @rdname spark.als -#' @export setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) #' @rdname spark.bisectingKmeans -#' @export setGeneric("spark.bisectingKmeans", function(data, formula, ...) { standardGeneric("spark.bisectingKmeans") }) #' @rdname spark.gaussianMixture -#' @export setGeneric("spark.gaussianMixture", function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) #' @rdname spark.gbt -#' @export setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") }) #' @rdname spark.glm -#' @export setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) #' @rdname spark.isoreg -#' @export setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) #' @rdname spark.kmeans -#' @export setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") }) #' @rdname spark.kstest -#' @export setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) #' @rdname spark.lda -#' @export setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) #' @rdname spark.logit -#' @export setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @rdname spark.mlp -#' @export setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.mlp") }) #' @rdname spark.naiveBayes -#' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) #' @rdname spark.decisionTree -#' @export setGeneric("spark.decisionTree", function(data, formula, ...) { standardGeneric("spark.decisionTree") }) #' @rdname spark.randomForest -#' @export setGeneric("spark.randomForest", function(data, formula, ...) { standardGeneric("spark.randomForest") }) #' @rdname spark.survreg -#' @export setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) #' @rdname spark.svmLinear -#' @export setGeneric("spark.svmLinear", function(data, formula, ...) { standardGeneric("spark.svmLinear") }) #' @rdname spark.lda -#' @export setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") }) #' @rdname spark.lda -#' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.fpGrowth", function(data, ...) { standardGeneric("spark.fpGrowth") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.freqItemsets", function(object) { standardGeneric("spark.freqItemsets") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.associationRules", function(object) { standardGeneric("spark.associationRules") }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. #' @rdname write.ml -#' @export setGeneric("write.ml", function(object, path, ...) 
{ standardGeneric("write.ml") }) ###################### Streaming Methods ########################## #' @rdname awaitTermination -#' @export setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") }) #' @rdname isActive -#' @export setGeneric("isActive", function(x) { standardGeneric("isActive") }) #' @rdname lastProgress -#' @export setGeneric("lastProgress", function(x) { standardGeneric("lastProgress") }) #' @rdname queryName -#' @export setGeneric("queryName", function(x) { standardGeneric("queryName") }) #' @rdname status -#' @export setGeneric("status", function(x) { standardGeneric("status") }) #' @rdname stopQuery -#' @export setGeneric("stopQuery", function(x) { standardGeneric("stopQuery") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 54ef9f07d6fa..f751b952f391 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -30,7 +30,6 @@ setOldClass("jobj") #' @seealso groupBy #' #' @param sgd A Java object reference to the backing Scala GroupedData -#' @export #' @note GroupedData since 1.4.0 setClass("GroupedData", slots = list(sgd = "jobj")) @@ -48,7 +47,6 @@ groupedData <- function(sgd) { #' @rdname show #' @aliases show,GroupedData-method -#' @export #' @note show(GroupedData) since 1.4.0 setMethod("show", "GroupedData", function(object) { @@ -63,7 +61,6 @@ setMethod("show", "GroupedData", #' @return A SparkDataFrame. #' @rdname count #' @aliases count,GroupedData-method -#' @export #' @examples #' \dontrun{ #' count(groupBy(df, "name")) @@ -87,7 +84,6 @@ setMethod("count", #' @aliases agg,GroupedData-method #' @name agg #' @family agg_funcs -#' @export #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' @@ -150,7 +146,6 @@ methods <- c("avg", "max", "mean", "min", "sum") #' @rdname pivot #' @aliases pivot,GroupedData,character-method #' @name pivot -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame( @@ -202,7 +197,6 @@ createMethods() #' @rdname gapply #' @aliases gapply,GroupedData-method #' @name gapply -#' @export #' @note gapply(GroupedData) since 2.0.0 setMethod("gapply", signature(x = "GroupedData"), @@ -216,7 +210,6 @@ setMethod("gapply", #' @rdname gapplyCollect #' @aliases gapplyCollect,GroupedData-method #' @name gapplyCollect -#' @export #' @note gapplyCollect(GroupedData) since 2.0.0 setMethod("gapplyCollect", signature(x = "GroupedData"), diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 04dc7562e534..6d1edf6b6f3c 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -58,7 +58,6 @@ #' @rdname install.spark #' @name install.spark #' @aliases install.spark -#' @export #' @examples #'\dontrun{ #' install.spark() diff --git a/R/pkg/R/jvm.R b/R/pkg/R/jvm.R index bb5c77544a3d..9a1b26b0fa3c 100644 --- a/R/pkg/R/jvm.R +++ b/R/pkg/R/jvm.R @@ -35,7 +35,6 @@ #' @param ... parameters to pass to the Java method. #' @return the return value of the Java method. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJStatic}, \link{sparkR.newJObject} #' @rdname sparkR.callJMethod #' @examples @@ -69,7 +68,6 @@ sparkR.callJMethod <- function(x, methodName, ...) { #' @param ... parameters to pass to the Java method. #' @return the return value of the Java method. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. 
-#' @export #' @seealso \link{sparkR.callJMethod}, \link{sparkR.newJObject} #' @rdname sparkR.callJStatic #' @examples @@ -100,7 +98,6 @@ sparkR.callJStatic <- function(x, methodName, ...) { #' @param ... arguments to be passed to the constructor. #' @return the object created. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJMethod}, \link{sparkR.callJStatic} #' @rdname sparkR.newJObject #' @examples diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 7cd072a1d6f8..2964fdeff095 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -21,28 +21,24 @@ #' S4 class that represents an LinearSVCModel #' #' @param jobj a Java object reference to the backing Scala LinearSVCModel -#' @export #' @note LinearSVCModel since 2.2.0 setClass("LinearSVCModel", representation(jobj = "jobj")) #' S4 class that represents an LogisticRegressionModel #' #' @param jobj a Java object reference to the backing Scala LogisticRegressionModel -#' @export #' @note LogisticRegressionModel since 2.1.0 setClass("LogisticRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a MultilayerPerceptronClassificationModel #' #' @param jobj a Java object reference to the backing Scala MultilayerPerceptronClassifierWrapper -#' @export #' @note MultilayerPerceptronClassificationModel since 2.1.0 setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a NaiveBayesModel #' #' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper -#' @export #' @note NaiveBayesModel since 2.0.0 setClass("NaiveBayesModel", representation(jobj = "jobj")) @@ -82,7 +78,6 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @rdname spark.svmLinear #' @aliases spark.svmLinear,SparkDataFrame,formula-method #' @name spark.svmLinear -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -131,7 +126,6 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu #' @return \code{predict} returns the predicted values based on a LinearSVCModel. #' @rdname spark.svmLinear #' @aliases predict,LinearSVCModel,SparkDataFrame-method -#' @export #' @note predict(LinearSVCModel) since 2.2.0 setMethod("predict", signature(object = "LinearSVCModel"), function(object, newData) { @@ -146,7 +140,6 @@ setMethod("predict", signature(object = "LinearSVCModel"), #' \code{numClasses} (number of classes), \code{numFeatures} (number of features). 
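# Illustrative aside (not part of this patch): the low-level JVM bridge documented above,
# calling directly into java.util and java.lang from an active SparkR session.
sparkR.session()
jarr <- sparkR.newJObject("java.util.ArrayList")      # new ArrayList()
sparkR.callJMethod(jarr, "add", 42L)                  # jarr.add(42)
sparkR.callJMethod(jarr, "get", 0L)                   # returns 42
sparkR.callJStatic("java.lang.Math", "max", 1L, 5L)   # returns 5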
#' @rdname spark.svmLinear #' @aliases summary,LinearSVCModel-method -#' @export #' @note summary(LinearSVCModel) since 2.2.0 setMethod("summary", signature(object = "LinearSVCModel"), function(object) { @@ -169,7 +162,6 @@ setMethod("summary", signature(object = "LinearSVCModel"), #' #' @rdname spark.svmLinear #' @aliases write.ml,LinearSVCModel,character-method -#' @export #' @note write.ml(LogisticRegression, character) since 2.2.0 setMethod("write.ml", signature(object = "LinearSVCModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -257,7 +249,6 @@ function(object, path, overwrite = FALSE) { #' @rdname spark.logit #' @aliases spark.logit,SparkDataFrame,formula-method #' @name spark.logit -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -279,11 +270,24 @@ function(object, path, overwrite = FALSE) { #' savedModel <- read.ml(path) #' summary(savedModel) #' -#' # multinomial logistic regression +#' # binary logistic regression against two classes with +#' # upperBoundsOnCoefficients and upperBoundsOnIntercepts +#' ubc <- matrix(c(1.0, 0.0, 1.0, 0.0), nrow = 1, ncol = 4) +#' model <- spark.logit(training, Species ~ ., +#' upperBoundsOnCoefficients = ubc, +#' upperBoundsOnIntercepts = 1.0) #' +#' # multinomial logistic regression #' model <- spark.logit(training, Class ~ ., regParam = 0.5) #' summary <- summary(model) #' +#' # multinomial logistic regression with +#' # lowerBoundsOnCoefficients and lowerBoundsOnIntercepts +#' lbc <- matrix(c(0.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0), nrow = 2, ncol = 4) +#' lbi <- as.array(c(0.0, 0.0)) +#' model <- spark.logit(training, Species ~ ., family = "multinomial", +#' lowerBoundsOnCoefficients = lbc, +#' lowerBoundsOnIntercepts = lbi) #' } #' @note spark.logit since 2.1.0 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), @@ -361,7 +365,6 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") #' The list includes \code{coefficients} (coefficients matrix of the fitted model). #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method -#' @export #' @note summary(LogisticRegressionModel) since 2.1.0 setMethod("summary", signature(object = "LogisticRegressionModel"), function(object) { @@ -389,7 +392,6 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' @return \code{predict} returns the predicted values based on an LogisticRegressionModel. 
#' @rdname spark.logit #' @aliases predict,LogisticRegressionModel,SparkDataFrame-method -#' @export #' @note predict(LogisticRegressionModel) since 2.1.0 setMethod("predict", signature(object = "LogisticRegressionModel"), function(object, newData) { @@ -404,7 +406,6 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), #' #' @rdname spark.logit #' @aliases write.ml,LogisticRegressionModel,character-method -#' @export #' @note write.ml(LogisticRegression, character) since 2.1.0 setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -445,7 +446,6 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @aliases spark.mlp,SparkDataFrame,formula-method #' @name spark.mlp #' @seealso \link{read.ml} -#' @export #' @examples #' \dontrun{ #' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") @@ -504,7 +504,6 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), #' For \code{weights}, it is a numeric vector with length equal to the expected #' given the architecture (i.e., for 8-10-2 network, 112 connection weights). #' @rdname spark.mlp -#' @export #' @aliases summary,MultilayerPerceptronClassificationModel-method #' @note summary(MultilayerPerceptronClassificationModel) since 2.1.0 setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"), @@ -525,7 +524,6 @@ setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel #' "prediction". #' @rdname spark.mlp #' @aliases predict,MultilayerPerceptronClassificationModel-method -#' @export #' @note predict(MultilayerPerceptronClassificationModel) since 2.1.0 setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel"), function(object, newData) { @@ -540,7 +538,6 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel #' #' @rdname spark.mlp #' @aliases write.ml,MultilayerPerceptronClassificationModel,character-method -#' @export #' @seealso \link{write.ml} #' @note write.ml(MultilayerPerceptronClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationModel", @@ -572,7 +569,6 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' @aliases spark.naiveBayes,SparkDataFrame,formula-method #' @name spark.naiveBayes #' @seealso e1071: \url{https://cran.r-project.org/package=e1071} -#' @export #' @examples #' \dontrun{ #' data <- as.data.frame(UCBAdmissions) @@ -611,7 +607,6 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form #' The list includes \code{apriori} (the label distribution) and #' \code{tables} (conditional probabilities given the target label). #' @rdname spark.naiveBayes -#' @export #' @note summary(NaiveBayesModel) since 2.0.0 setMethod("summary", signature(object = "NaiveBayesModel"), function(object) { @@ -635,7 +630,6 @@ setMethod("summary", signature(object = "NaiveBayesModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named #' "prediction". #' @rdname spark.naiveBayes -#' @export #' @note predict(NaiveBayesModel) since 2.0.0 setMethod("predict", signature(object = "NaiveBayesModel"), function(object, newData) { @@ -649,7 +643,6 @@ setMethod("predict", signature(object = "NaiveBayesModel"), #' which means throw exception if the output path exists. 
#' #' @rdname spark.naiveBayes -#' @export #' @seealso \link{write.ml} #' @note write.ml(NaiveBayesModel, character) since 2.0.0 setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index a25bf81c6d97..900be685824d 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -20,28 +20,24 @@ #' S4 class that represents a BisectingKMeansModel #' #' @param jobj a Java object reference to the backing Scala BisectingKMeansModel -#' @export #' @note BisectingKMeansModel since 2.2.0 setClass("BisectingKMeansModel", representation(jobj = "jobj")) #' S4 class that represents a GaussianMixtureModel #' #' @param jobj a Java object reference to the backing Scala GaussianMixtureModel -#' @export #' @note GaussianMixtureModel since 2.1.0 setClass("GaussianMixtureModel", representation(jobj = "jobj")) #' S4 class that represents a KMeansModel #' #' @param jobj a Java object reference to the backing Scala KMeansModel -#' @export #' @note KMeansModel since 2.0.0 setClass("KMeansModel", representation(jobj = "jobj")) #' S4 class that represents an LDAModel #' #' @param jobj a Java object reference to the backing Scala LDAWrapper -#' @export #' @note LDAModel since 2.1.0 setClass("LDAModel", representation(jobj = "jobj")) @@ -68,7 +64,6 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @rdname spark.bisectingKmeans #' @aliases spark.bisectingKmeans,SparkDataFrame,formula-method #' @name spark.bisectingKmeans -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -117,7 +112,6 @@ setMethod("spark.bisectingKmeans", signature(data = "SparkDataFrame", formula = #' (cluster centers of the transformed data; cluster is NULL if is.loaded is TRUE), #' and \code{is.loaded} (whether the model is loaded from a saved file). #' @rdname spark.bisectingKmeans -#' @export #' @note summary(BisectingKMeansModel) since 2.2.0 setMethod("summary", signature(object = "BisectingKMeansModel"), function(object) { @@ -144,7 +138,6 @@ setMethod("summary", signature(object = "BisectingKMeansModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns the predicted values based on a bisecting k-means model. #' @rdname spark.bisectingKmeans -#' @export #' @note predict(BisectingKMeansModel) since 2.2.0 setMethod("predict", signature(object = "BisectingKMeansModel"), function(object, newData) { @@ -160,7 +153,6 @@ setMethod("predict", signature(object = "BisectingKMeansModel"), #' or \code{"classes"} for assigned classes. #' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname spark.bisectingKmeans -#' @export #' @note fitted since 2.2.0 setMethod("fitted", signature(object = "BisectingKMeansModel"), function(object, method = c("centers", "classes")) { @@ -181,7 +173,6 @@ setMethod("fitted", signature(object = "BisectingKMeansModel"), #' which means throw exception if the output path exists. 
#' #' @rdname spark.bisectingKmeans -#' @export #' @note write.ml(BisectingKMeansModel, character) since 2.2.0 setMethod("write.ml", signature(object = "BisectingKMeansModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -208,7 +199,6 @@ setMethod("write.ml", signature(object = "BisectingKMeansModel", path = "charact #' @rdname spark.gaussianMixture #' @name spark.gaussianMixture #' @seealso mixtools: \url{https://cran.r-project.org/package=mixtools} -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -251,7 +241,6 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = #' \code{sigma} (sigma), \code{loglik} (loglik), and \code{posterior} (posterior). #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture -#' @export #' @note summary(GaussianMixtureModel) since 2.1.0 setMethod("summary", signature(object = "GaussianMixtureModel"), function(object) { @@ -291,7 +280,6 @@ setMethod("summary", signature(object = "GaussianMixtureModel"), #' "prediction". #' @aliases predict,GaussianMixtureModel,SparkDataFrame-method #' @rdname spark.gaussianMixture -#' @export #' @note predict(GaussianMixtureModel) since 2.1.0 setMethod("predict", signature(object = "GaussianMixtureModel"), function(object, newData) { @@ -306,7 +294,6 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' #' @aliases write.ml,GaussianMixtureModel,character-method #' @rdname spark.gaussianMixture -#' @export #' @note write.ml(GaussianMixtureModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -336,7 +323,6 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' @rdname spark.kmeans #' @aliases spark.kmeans,SparkDataFrame,formula-method #' @name spark.kmeans -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -385,7 +371,6 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula" #' (the actual number of cluster centers. When using initMode = "random", #' \code{clusterSize} may not equal to \code{k}). #' @rdname spark.kmeans -#' @export #' @note summary(KMeansModel) since 2.0.0 setMethod("summary", signature(object = "KMeansModel"), function(object) { @@ -413,7 +398,6 @@ setMethod("summary", signature(object = "KMeansModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns the predicted values based on a k-means model. #' @rdname spark.kmeans -#' @export #' @note predict(KMeansModel) since 2.0.0 setMethod("predict", signature(object = "KMeansModel"), function(object, newData) { @@ -431,7 +415,6 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param ... additional argument(s) passed to the method. #' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname fitted -#' @export #' @examples #' \dontrun{ #' model <- spark.kmeans(trainingData, ~ ., 2) @@ -458,7 +441,6 @@ setMethod("fitted", signature(object = "KMeansModel"), #' which means throw exception if the output path exists. 
#' #' @rdname spark.kmeans -#' @export #' @note write.ml(KMeansModel, character) since 2.0.0 setMethod("write.ml", signature(object = "KMeansModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -496,7 +478,6 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"), #' @rdname spark.lda #' @aliases spark.lda,SparkDataFrame-method #' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} -#' @export #' @examples #' \dontrun{ #' text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") @@ -558,7 +539,6 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), #' It is only for distributed LDA model (i.e., optimizer = "em")} #' @rdname spark.lda #' @aliases summary,LDAModel-method -#' @export #' @note summary(LDAModel) since 2.1.0 setMethod("summary", signature(object = "LDAModel"), function(object, maxTermsPerTopic) { @@ -596,7 +576,6 @@ setMethod("summary", signature(object = "LDAModel"), #' perplexity of the training data if missing argument "data". #' @rdname spark.lda #' @aliases spark.perplexity,LDAModel-method -#' @export #' @note spark.perplexity(LDAModel) since 2.1.0 setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"), function(object, data) { @@ -611,7 +590,6 @@ setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFr #' vectors named "topicDistribution". #' @rdname spark.lda #' @aliases spark.posterior,LDAModel,SparkDataFrame-method -#' @export #' @note spark.posterior(LDAModel) since 2.1.0 setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"), function(object, newData) { @@ -626,7 +604,6 @@ setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkData #' #' @rdname spark.lda #' @aliases write.ml,LDAModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(LDAModel, character) since 2.1.0 setMethod("write.ml", signature(object = "LDAModel", path = "character"), diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index dfcb45a1b66c..e2394906d801 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -20,7 +20,6 @@ #' S4 class that represents a FPGrowthModel #' #' @param jobj a Java object reference to the backing Scala FPGrowthModel -#' @export #' @note FPGrowthModel since 2.2.0 setClass("FPGrowthModel", slots = list(jobj = "jobj")) @@ -45,7 +44,6 @@ setClass("FPGrowthModel", slots = list(jobj = "jobj")) #' @rdname spark.fpGrowth #' @name spark.fpGrowth #' @aliases spark.fpGrowth,SparkDataFrame-method -#' @export #' @examples #' \dontrun{ #' raw_data <- read.df( @@ -109,7 +107,6 @@ setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"), #' and \code{freq} (frequency of the itemset). #' @rdname spark.fpGrowth #' @aliases freqItemsets,FPGrowthModel-method -#' @export #' @note spark.freqItemsets(FPGrowthModel) since 2.2.0 setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), function(object) { @@ -125,7 +122,6 @@ setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), #' and \code{condfidence} (confidence). #' @rdname spark.fpGrowth #' @aliases associationRules,FPGrowthModel-method -#' @export #' @note spark.associationRules(FPGrowthModel) since 2.2.0 setMethod("spark.associationRules", signature(object = "FPGrowthModel"), function(object) { @@ -138,7 +134,6 @@ setMethod("spark.associationRules", signature(object = "FPGrowthModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. 
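# Illustrative aside (not part of this patch): a sketch of the FP-growth workflow exposed
# by the generics above; the toy transactions are made up for illustration.
transactions <- selectExpr(
  createDataFrame(data.frame(items = c("a,b", "a,b,c", "b"))),
  "split(items, ',') AS items")
model <- spark.fpGrowth(transactions, itemsCol = "items",
                        minSupport = 0.5, minConfidence = 0.6)
head(spark.freqItemsets(model))        # frequent itemsets and their frequencies
head(spark.associationRules(model))    # antecedent, consequent, confidence
head(predict(model, transactions))     # consequents suggested per row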
#' @rdname spark.fpGrowth #' @aliases predict,FPGrowthModel-method -#' @export #' @note predict(FPGrowthModel) since 2.2.0 setMethod("predict", signature(object = "FPGrowthModel"), function(object, newData) { @@ -153,7 +148,6 @@ setMethod("predict", signature(object = "FPGrowthModel"), #' if the output path exists. #' @rdname spark.fpGrowth #' @aliases write.ml,FPGrowthModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(FPGrowthModel, character) since 2.2.0 setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R index fa794249085d..9a77b0746258 100644 --- a/R/pkg/R/mllib_recommendation.R +++ b/R/pkg/R/mllib_recommendation.R @@ -20,7 +20,6 @@ #' S4 class that represents an ALSModel #' #' @param jobj a Java object reference to the backing Scala ALSWrapper -#' @export #' @note ALSModel since 2.1.0 setClass("ALSModel", representation(jobj = "jobj")) @@ -48,12 +47,13 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @param numUserBlocks number of user blocks used to parallelize computation (> 0). #' @param numItemBlocks number of item blocks used to parallelize computation (> 0). #' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.als} returns a fitted ALS model. #' @rdname spark.als #' @aliases spark.als,SparkDataFrame-method #' @name spark.als -#' @export #' @examples #' \dontrun{ #' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -116,7 +116,6 @@ setMethod("spark.als", signature(data = "SparkDataFrame"), #' and \code{rank} (rank of the matrix factorization model). #' @rdname spark.als #' @aliases summary,ALSModel-method -#' @export #' @note summary(ALSModel) since 2.1.0 setMethod("summary", signature(object = "ALSModel"), function(object) { @@ -137,7 +136,6 @@ setMethod("summary", signature(object = "ALSModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. 
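# Illustrative aside (not part of this patch): per the note added above, checkpointInterval
# only takes effect once a checkpoint directory is set, e.g. via setCheckpointDir; the
# directory path and toy ratings are made up.
setCheckpointDir("/tmp/spark-checkpoints")   # hypothetical directory
ratings <- createDataFrame(
  list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0)),
  c("user", "item", "rating"))
model <- spark.als(ratings, "rating", "user", "item",
                   rank = 10, maxIter = 5, checkpointInterval = 2)
summary(model)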
#' @rdname spark.als #' @aliases predict,ALSModel-method -#' @export #' @note predict(ALSModel) since 2.1.0 setMethod("predict", signature(object = "ALSModel"), function(object, newData) { @@ -153,7 +151,6 @@ setMethod("predict", signature(object = "ALSModel"), #' #' @rdname spark.als #' @aliases write.ml,ALSModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(ALSModel, character) since 2.1.0 setMethod("write.ml", signature(object = "ALSModel", path = "character"), diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 545be5e1d89f..95c1a2990519 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -21,21 +21,18 @@ #' S4 class that represents a AFTSurvivalRegressionModel #' #' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper -#' @export #' @note AFTSurvivalRegressionModel since 2.0.0 setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a generalized linear model #' #' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper -#' @export #' @note GeneralizedLinearRegressionModel since 2.0.0 setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) #' S4 class that represents an IsotonicRegressionModel #' #' @param jobj a Java object reference to the backing Scala IsotonicRegressionModel -#' @export #' @note IsotonicRegressionModel since 2.1.0 setClass("IsotonicRegressionModel", representation(jobj = "jobj")) @@ -85,7 +82,6 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @return \code{spark.glm} returns a fitted generalized linear model. #' @rdname spark.glm #' @name spark.glm -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -211,7 +207,6 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @aliases glm -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -244,7 +239,6 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat #' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in #' the data, the coefficients matrix only provides coefficients. #' @rdname spark.glm -#' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), function(object) { @@ -290,7 +284,6 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), #' @rdname spark.glm #' @param x summary object of fitted generalized linear model returned by \code{summary} function. -#' @export #' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { if (x$is.loaded) { @@ -324,7 +317,6 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { #' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named #' "prediction". #' @rdname spark.glm -#' @export #' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), function(object, newData) { @@ -338,7 +330,6 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), #' which means throw exception if the output path exists. 
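# Illustrative aside (not part of this patch): spark.glm mirrors R's glm() conventions;
# iris column names gain underscores when converted to a SparkDataFrame.
training <- createDataFrame(iris)
model <- spark.glm(training, Sepal_Length ~ Sepal_Width + Species, family = "gaussian")
summary(model)                  # coefficients, deviance, AIC, ...
head(predict(model, training))  # adds a "prediction" column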
#' #' @rdname spark.glm -#' @export #' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0 setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -363,7 +354,6 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat #' @rdname spark.isoreg #' @aliases spark.isoreg,SparkDataFrame,formula-method #' @name spark.isoreg -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -412,7 +402,6 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula" #' and \code{predictions} (predictions associated with the boundaries at the same index). #' @rdname spark.isoreg #' @aliases summary,IsotonicRegressionModel-method -#' @export #' @note summary(IsotonicRegressionModel) since 2.1.0 setMethod("summary", signature(object = "IsotonicRegressionModel"), function(object) { @@ -429,7 +418,6 @@ setMethod("summary", signature(object = "IsotonicRegressionModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.isoreg #' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method -#' @export #' @note predict(IsotonicRegressionModel) since 2.1.0 setMethod("predict", signature(object = "IsotonicRegressionModel"), function(object, newData) { @@ -444,7 +432,6 @@ setMethod("predict", signature(object = "IsotonicRegressionModel"), #' #' @rdname spark.isoreg #' @aliases write.ml,IsotonicRegressionModel,character-method -#' @export #' @note write.ml(IsotonicRegression, character) since 2.1.0 setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -477,7 +464,6 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg #' @seealso survival: \url{https://cran.r-project.org/package=survival} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(ovarian) @@ -517,7 +503,6 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' The list includes the model's \code{coefficients} (features, coefficients, #' intercept and log(scale)). #' @rdname spark.survreg -#' @export #' @note summary(AFTSurvivalRegressionModel) since 2.0.0 setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), function(object) { @@ -537,7 +522,6 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values #' on the original scale of the data (mean predicted value at scale = 1.0). #' @rdname spark.survreg -#' @export #' @note predict(AFTSurvivalRegressionModel) since 2.0.0 setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), function(object, newData) { @@ -550,7 +534,6 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), #' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. 
#' @rdname spark.survreg -#' @export #' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0 #' @seealso \link{write.ml} setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"), diff --git a/R/pkg/R/mllib_stat.R b/R/pkg/R/mllib_stat.R index 3e013f1d45e3..f8c332935996 100644 --- a/R/pkg/R/mllib_stat.R +++ b/R/pkg/R/mllib_stat.R @@ -20,7 +20,6 @@ #' S4 class that represents an KSTest #' #' @param jobj a Java object reference to the backing Scala KSTestWrapper -#' @export #' @note KSTest since 2.1.0 setClass("KSTest", representation(jobj = "jobj")) @@ -52,7 +51,6 @@ setClass("KSTest", representation(jobj = "jobj")) #' @name spark.kstest #' @seealso \href{http://spark.apache.org/docs/latest/mllib-statistics.html#hypothesis-testing}{ #' MLlib: Hypothesis Testing} -#' @export #' @examples #' \dontrun{ #' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) @@ -94,7 +92,6 @@ setMethod("spark.kstest", signature(data = "SparkDataFrame"), #' parameters tested against) and \code{degreesOfFreedom} (degrees of freedom of the test). #' @rdname spark.kstest #' @aliases summary,KSTest-method -#' @export #' @note summary(KSTest) since 2.1.0 setMethod("summary", signature(object = "KSTest"), function(object) { @@ -117,7 +114,6 @@ setMethod("summary", signature(object = "KSTest"), #' @rdname spark.kstest #' @param x summary object of KSTest returned by \code{summary}. -#' @export #' @note print.summary.KSTest since 2.1.0 print.summary.KSTest <- function(x, ...) { jobj <- x$jobj diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 89a58bf0aada..6769be038efa 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -20,42 +20,36 @@ #' S4 class that represents a GBTRegressionModel #' #' @param jobj a Java object reference to the backing Scala GBTRegressionModel -#' @export #' @note GBTRegressionModel since 2.1.0 setClass("GBTRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a GBTClassificationModel #' #' @param jobj a Java object reference to the backing Scala GBTClassificationModel -#' @export #' @note GBTClassificationModel since 2.1.0 setClass("GBTClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a RandomForestRegressionModel #' #' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel -#' @export #' @note RandomForestRegressionModel since 2.1.0 setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a RandomForestClassificationModel #' #' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel -#' @export #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a DecisionTreeRegressionModel #' #' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel -#' @export #' @note DecisionTreeRegressionModel since 2.3.0 setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a DecisionTreeClassificationModel #' #' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel -#' @export #' @note DecisionTreeClassificationModel since 2.3.0 setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) @@ -161,6 +155,8 @@ print.summary.decisionTree <- function(x) { #' >= 1. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. 
#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching @@ -177,7 +173,6 @@ print.summary.decisionTree <- function(x) { #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. #' @rdname spark.gbt #' @name spark.gbt -#' @export #' @examples #' \dontrun{ #' # fit a Gradient Boosted Tree Regression Model @@ -259,7 +254,6 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), #' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.gbt #' @aliases summary,GBTRegressionModel-method -#' @export #' @note summary(GBTRegressionModel) since 2.1.0 setMethod("summary", signature(object = "GBTRegressionModel"), function(object) { @@ -273,7 +267,6 @@ setMethod("summary", signature(object = "GBTRegressionModel"), #' @param x summary object of Gradient Boosted Tree regression model or classification model #' returned by \code{summary}. #' @rdname spark.gbt -#' @export #' @note print.summary.GBTRegressionModel since 2.1.0 print.summary.GBTRegressionModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -283,7 +276,6 @@ print.summary.GBTRegressionModel <- function(x, ...) { #' @rdname spark.gbt #' @aliases summary,GBTClassificationModel-method -#' @export #' @note summary(GBTClassificationModel) since 2.1.0 setMethod("summary", signature(object = "GBTClassificationModel"), function(object) { @@ -295,7 +287,6 @@ setMethod("summary", signature(object = "GBTClassificationModel"), # Prints the summary of Gradient Boosted Tree Classification Model #' @rdname spark.gbt -#' @export #' @note print.summary.GBTClassificationModel since 2.1.0 print.summary.GBTClassificationModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -308,7 +299,6 @@ print.summary.GBTClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.gbt #' @aliases predict,GBTRegressionModel-method -#' @export #' @note predict(GBTRegressionModel) since 2.1.0 setMethod("predict", signature(object = "GBTRegressionModel"), function(object, newData) { @@ -317,7 +307,6 @@ setMethod("predict", signature(object = "GBTRegressionModel"), #' @rdname spark.gbt #' @aliases predict,GBTClassificationModel-method -#' @export #' @note predict(GBTClassificationModel) since 2.1.0 setMethod("predict", signature(object = "GBTClassificationModel"), function(object, newData) { @@ -332,7 +321,6 @@ setMethod("predict", signature(object = "GBTClassificationModel"), #' which means throw exception if the output path exists. 
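# Illustrative aside (not part of this patch): as the added note says, checkpointInterval
# for tree ensembles is ignored unless a checkpoint directory has been set; the directory
# path is hypothetical and the built-in longley data set stands in for real data.
setCheckpointDir("/tmp/spark-checkpoints")
df <- createDataFrame(longley)
model <- spark.gbt(df, Employed ~ ., type = "regression",
                   maxDepth = 5, maxIter = 20, checkpointInterval = 10)
summary(model)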
#' @aliases write.ml,GBTRegressionModel,character-method #' @rdname spark.gbt -#' @export #' @note write.ml(GBTRegressionModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -341,7 +329,6 @@ setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character #' @aliases write.ml,GBTClassificationModel,character-method #' @rdname spark.gbt -#' @export #' @note write.ml(GBTClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -382,6 +369,8 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' @param minInstancesPerNode Minimum number of instances each child must have after split. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching @@ -398,7 +387,6 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' @return \code{spark.randomForest} returns a fitted Random Forest model. #' @rdname spark.randomForest #' @name spark.randomForest -#' @export #' @examples #' \dontrun{ #' # fit a Random Forest Regression Model @@ -476,7 +464,6 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method -#' @export #' @note summary(RandomForestRegressionModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestRegressionModel"), function(object) { @@ -490,7 +477,6 @@ setMethod("summary", signature(object = "RandomForestRegressionModel"), #' @param x summary object of Random Forest regression model or classification model #' returned by \code{summary}. #' @rdname spark.randomForest -#' @export #' @note print.summary.RandomForestRegressionModel since 2.1.0 print.summary.RandomForestRegressionModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -500,7 +486,6 @@ print.summary.RandomForestRegressionModel <- function(x, ...) { #' @rdname spark.randomForest #' @aliases summary,RandomForestClassificationModel-method -#' @export #' @note summary(RandomForestClassificationModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestClassificationModel"), function(object) { @@ -512,7 +497,6 @@ setMethod("summary", signature(object = "RandomForestClassificationModel"), # Prints the summary of Random Forest Classification Model #' @rdname spark.randomForest -#' @export #' @note print.summary.RandomForestClassificationModel since 2.1.0 print.summary.RandomForestClassificationModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -525,7 +509,6 @@ print.summary.RandomForestClassificationModel <- function(x, ...) { #' "prediction". 
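For context, the fit/predict round trip these docs describe looks roughly like the following; this is a sketch along the lines of the package examples (the dataset and column selection are illustrative), not new behaviour introduced by this patch:

```r
df <- createDataFrame(longley)
model <- spark.randomForest(df, Employed ~ ., type = "regression", numTrees = 10)

preds <- predict(model, df)                      # adds a "prediction" column
head(select(preds, "Employed", "prediction"))
```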
#' @rdname spark.randomForest #' @aliases predict,RandomForestRegressionModel-method -#' @export #' @note predict(RandomForestRegressionModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestRegressionModel"), function(object, newData) { @@ -534,7 +517,6 @@ setMethod("predict", signature(object = "RandomForestRegressionModel"), #' @rdname spark.randomForest #' @aliases predict,RandomForestClassificationModel-method -#' @export #' @note predict(RandomForestClassificationModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestClassificationModel"), function(object, newData) { @@ -550,7 +532,6 @@ setMethod("predict", signature(object = "RandomForestClassificationModel"), #' #' @aliases write.ml,RandomForestRegressionModel,character-method #' @rdname spark.randomForest -#' @export #' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -559,7 +540,6 @@ setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = " #' @aliases write.ml,RandomForestClassificationModel,character-method #' @rdname spark.randomForest -#' @export #' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -595,6 +575,8 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' @param minInstancesPerNode Minimum number of instances each child must have after split. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching @@ -611,7 +593,6 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. #' @rdname spark.decisionTree #' @name spark.decisionTree -#' @export #' @examples #' \dontrun{ #' # fit a Decision Tree Regression Model @@ -684,7 +665,6 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo #' trees). #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeRegressionModel-method -#' @export #' @note summary(DecisionTreeRegressionModel) since 2.3.0 setMethod("summary", signature(object = "DecisionTreeRegressionModel"), function(object) { @@ -698,7 +678,6 @@ setMethod("summary", signature(object = "DecisionTreeRegressionModel"), #' @param x summary object of Decision Tree regression model or classification model #' returned by \code{summary}. #' @rdname spark.decisionTree -#' @export #' @note print.summary.DecisionTreeRegressionModel since 2.3.0 print.summary.DecisionTreeRegressionModel <- function(x, ...) { print.summary.decisionTree(x) @@ -708,7 +687,6 @@ print.summary.DecisionTreeRegressionModel <- function(x, ...) 
{ #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeClassificationModel-method -#' @export #' @note summary(DecisionTreeClassificationModel) since 2.3.0 setMethod("summary", signature(object = "DecisionTreeClassificationModel"), function(object) { @@ -720,7 +698,6 @@ setMethod("summary", signature(object = "DecisionTreeClassificationModel"), # Prints the summary of Decision Tree Classification Model #' @rdname spark.decisionTree -#' @export #' @note print.summary.DecisionTreeClassificationModel since 2.3.0 print.summary.DecisionTreeClassificationModel <- function(x, ...) { print.summary.decisionTree(x) @@ -733,7 +710,6 @@ print.summary.DecisionTreeClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.decisionTree #' @aliases predict,DecisionTreeRegressionModel-method -#' @export #' @note predict(DecisionTreeRegressionModel) since 2.3.0 setMethod("predict", signature(object = "DecisionTreeRegressionModel"), function(object, newData) { @@ -742,7 +718,6 @@ setMethod("predict", signature(object = "DecisionTreeRegressionModel"), #' @rdname spark.decisionTree #' @aliases predict,DecisionTreeClassificationModel-method -#' @export #' @note predict(DecisionTreeClassificationModel) since 2.3.0 setMethod("predict", signature(object = "DecisionTreeClassificationModel"), function(object, newData) { @@ -758,7 +733,6 @@ setMethod("predict", signature(object = "DecisionTreeClassificationModel"), #' #' @aliases write.ml,DecisionTreeRegressionModel,character-method #' @rdname spark.decisionTree -#' @export #' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0 setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -767,7 +741,6 @@ setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = " #' @aliases write.ml,DecisionTreeClassificationModel,character-method #' @rdname spark.decisionTree -#' @export #' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0 setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index a53c92c2c481..7d04bffcba3a 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -31,7 +31,6 @@ #' MLlib model below. #' @rdname write.ml #' @name write.ml -#' @export #' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, #' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, #' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, @@ -48,7 +47,6 @@ NULL #' MLlib model below. #' @rdname predict #' @name predict -#' @export #' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, #' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, #' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, @@ -75,7 +73,6 @@ predict_internal <- function(object, newData) { #' @return A fitted MLlib model. #' @rdname read.ml #' @name read.ml -#' @export #' @seealso \link{write.ml} #' @examples #' \dontrun{ diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 65f418740c64..9831fc3cc6d0 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -29,7 +29,6 @@ #' @param ... 
additional structField objects #' @return a structType object #' @rdname structType -#' @export #' @examples #'\dontrun{ #' schema <- structType(structField("a", "integer"), structField("c", "string"), @@ -49,7 +48,6 @@ structType <- function(x, ...) { #' @rdname structType #' @method structType jobj -#' @export structType.jobj <- function(x, ...) { obj <- structure(list(), class = "structType") obj$jobj <- x @@ -59,7 +57,6 @@ structType.jobj <- function(x, ...) { #' @rdname structType #' @method structType structField -#' @export structType.structField <- function(x, ...) { fields <- list(x, ...) if (!all(sapply(fields, inherits, "structField"))) { @@ -76,7 +73,6 @@ structType.structField <- function(x, ...) { #' @rdname structType #' @method structType character -#' @export structType.character <- function(x, ...) { if (!is.character(x)) { stop("schema must be a DDL-formatted string.") @@ -119,7 +115,6 @@ print.structType <- function(x, ...) { #' @param ... additional argument(s) passed to the method. #' @return A structField object. #' @rdname structField -#' @export #' @examples #'\dontrun{ #' field1 <- structField("a", "integer") @@ -137,7 +132,6 @@ structField <- function(x, ...) { #' @rdname structField #' @method structField jobj -#' @export structField.jobj <- function(x, ...) { obj <- structure(list(), class = "structField") obj$jobj <- x @@ -212,7 +206,6 @@ checkType <- function(type) { #' @param type The data type of the field #' @param nullable A logical vector indicating whether or not the field is nullable #' @rdname structField -#' @export structField.character <- function(x, type, nullable = TRUE, ...) { if (class(x) != "character") { stop("Field name must be a string.") diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 3bbf60d9b668..263b9b576c0c 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -30,14 +30,17 @@ # POSIXct,POSIXlt -> Time # # list[T] -> Array[T], where T is one of above mentioned types +# Multi-element vector of any of the above (except raw) -> Array[T] # environment -> Map[String, T], where T is a native type # jobj -> Object, where jobj is an object created in the backend # nolint end getSerdeType <- function(object) { type <- class(object)[[1]] - if (type != "list") { - type + if (is.atomic(object) & !is.raw(object) & length(object) > 1) { + "array" + } else if (type != "list") { + type } else { # Check if all elements are of same type elemType <- unique(sapply(object, function(elem) { getSerdeType(elem) })) @@ -50,9 +53,7 @@ getSerdeType <- function(object) { } writeObject <- function(con, object, writeType = TRUE) { - # NOTE: In R vectors have same type as objects. So we don't support - # passing in vectors as arrays and instead require arrays to be passed - # as lists. + # NOTE: In R vectors have same type as objects type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") # Checking types is needed here, since 'is.na' only handles atomic vectors, # lists and pairlists diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 965471f3b07a..a480ac606f10 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -35,7 +35,6 @@ connExists <- function(env) { #' Also terminates the backend this R session is connected to. 
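The `getSerdeType` change above is what lets multi-element atomic vectors cross the R/JVM boundary as arrays instead of requiring explicit lists. A small sketch of what this enables, mirroring the new SerDe test added later in this patch (the column names are illustrative):

```r
local_df <- data.frame(id = 1:3)
local_df$scores <- list(c(1L, 2L, 3L))   # a list column; each row holds an integer vector

sdf <- as.DataFrame(local_df)            # "scores" is serialized as an array column
collected <- collect(sdf)
collected$scores[[1]]                    # comes back to R as a list of integers
```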
#' @rdname sparkR.session.stop #' @name sparkR.session.stop -#' @export #' @note sparkR.session.stop since 2.0.0 sparkR.session.stop <- function() { env <- .sparkREnv @@ -84,7 +83,6 @@ sparkR.session.stop <- function() { #' @rdname sparkR.session.stop #' @name sparkR.stop -#' @export #' @note sparkR.stop since 1.4.0 sparkR.stop <- function() { sparkR.session.stop() @@ -103,7 +101,6 @@ sparkR.stop <- function() { #' @param sparkPackages Character vector of package coordinates #' @seealso \link{sparkR.session} #' @rdname sparkR.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") @@ -270,7 +267,6 @@ sparkR.sparkContext <- function( #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @seealso \link{sparkR.session} #' @rdname sparkRSQL.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -298,7 +294,6 @@ sparkRSQL.init <- function(jsc = NULL) { #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @seealso \link{sparkR.session} #' @rdname sparkRHive.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -347,7 +342,6 @@ sparkRHive.init <- function(jsc = NULL) { #' @param enableHiveSupport enable support for Hive, fallback if not built with Hive support; once #' set, this cannot be turned off on an existing session #' @param ... named Spark properties passed to the method. -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -442,7 +436,6 @@ sparkR.session <- function( #' @return the SparkUI URL, or NA if it is disabled, or not started. #' @rdname sparkR.uiWebUrl #' @name sparkR.uiWebUrl -#' @export #' @examples #'\dontrun{ #' sparkR.session() diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index c8af798830b3..497f18c76304 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -37,7 +37,6 @@ setOldClass("jobj") #' @name crosstab #' @aliases crosstab,SparkDataFrame,character,character-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -63,7 +62,6 @@ setMethod("crosstab", #' @rdname cov #' @aliases cov,SparkDataFrame-method #' @family stat functions -#' @export #' @examples #' #' \dontrun{ @@ -92,7 +90,6 @@ setMethod("cov", #' @name corr #' @aliases corr,SparkDataFrame-method #' @family stat functions -#' @export #' @examples #' #' \dontrun{ @@ -124,7 +121,6 @@ setMethod("corr", #' @name freqItems #' @aliases freqItems,SparkDataFrame,character-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -168,7 +164,6 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' @name approxQuantile #' @aliases approxQuantile,SparkDataFrame,character,numeric,numeric-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -205,7 +200,6 @@ setMethod("approxQuantile", #' @aliases sampleBy,SparkDataFrame,character,list,numeric-method #' @name sampleBy #' @family stat functions -#' @export #' @examples #'\dontrun{ #' df <- read.json("/path/to/file.json") diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index 8390bd5e6de7..fc83463f72cd 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -28,7 +28,6 @@ NULL #' @seealso \link{read.stream} #' #' @param ssq A Java object reference to the backing Scala StreamingQuery -#' @export #' @note StreamingQuery since 2.2.0 #' @note experimental setClass("StreamingQuery", @@ -45,7 +44,6 @@ streamingQuery <- 
function(ssq) { } #' @rdname show -#' @export #' @note show(StreamingQuery) since 2.2.0 setMethod("show", "StreamingQuery", function(object) { @@ -70,7 +68,6 @@ setMethod("show", "StreamingQuery", #' @aliases queryName,StreamingQuery-method #' @family StreamingQuery methods #' @seealso \link{write.stream} -#' @export #' @examples #' \dontrun{ queryName(sq) } #' @note queryName(StreamingQuery) since 2.2.0 @@ -85,7 +82,6 @@ setMethod("queryName", #' @name explain #' @aliases explain,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ explain(sq) } #' @note explain(StreamingQuery) since 2.2.0 @@ -104,7 +100,6 @@ setMethod("explain", #' @name lastProgress #' @aliases lastProgress,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ lastProgress(sq) } #' @note lastProgress(StreamingQuery) since 2.2.0 @@ -129,7 +124,6 @@ setMethod("lastProgress", #' @name status #' @aliases status,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ status(sq) } #' @note status(StreamingQuery) since 2.2.0 @@ -150,7 +144,6 @@ setMethod("status", #' @name isActive #' @aliases isActive,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ isActive(sq) } #' @note isActive(StreamingQuery) since 2.2.0 @@ -177,7 +170,6 @@ setMethod("isActive", #' @name awaitTermination #' @aliases awaitTermination,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ awaitTermination(sq, 10000) } #' @note awaitTermination(StreamingQuery) since 2.2.0 @@ -202,7 +194,6 @@ setMethod("awaitTermination", #' @name stopQuery #' @aliases stopQuery,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ stopQuery(sq) } #' @note stopQuery(StreamingQuery) since 2.2.0 diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 164cd6d01a34..f1b5ecaa017d 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -108,7 +108,6 @@ isRDD <- function(name, env) { #' #' @param key the object to be hashed #' @return the hash code as an integer -#' @export #' @examples #'\dontrun{ #' hashCode(1L) # 1 diff --git a/R/pkg/R/window.R b/R/pkg/R/window.R index 0799d841e5dc..396b27bee80c 100644 --- a/R/pkg/R/window.R +++ b/R/pkg/R/window.R @@ -29,7 +29,6 @@ #' @rdname windowPartitionBy #' @name windowPartitionBy #' @aliases windowPartitionBy,character-method -#' @export #' @examples #' \dontrun{ #' ws <- orderBy(windowPartitionBy("key1", "key2"), "key3") @@ -52,7 +51,6 @@ setMethod("windowPartitionBy", #' @rdname windowPartitionBy #' @name windowPartitionBy #' @aliases windowPartitionBy,Column-method -#' @export #' @note windowPartitionBy(Column) since 2.0.0 setMethod("windowPartitionBy", signature(col = "Column"), @@ -78,7 +76,6 @@ setMethod("windowPartitionBy", #' @rdname windowOrderBy #' @name windowOrderBy #' @aliases windowOrderBy,character-method -#' @export #' @examples #' \dontrun{ #' ws <- windowOrderBy("key1", "key2") @@ -101,7 +98,6 @@ setMethod("windowOrderBy", #' @rdname windowOrderBy #' @name windowOrderBy #' @aliases windowOrderBy,Column-method -#' @export #' @note windowOrderBy(Column) since 2.0.0 setMethod("windowOrderBy", signature(col = "Column"), diff --git a/R/pkg/tests/fulltests/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R index 6bbd201bf1d8..3577929323b8 100644 --- a/R/pkg/tests/fulltests/test_Serde.R +++ b/R/pkg/tests/fulltests/test_Serde.R @@ -37,6 +37,53 @@ test_that("SerDe of primitive types", { 
expect_equal(class(x), "character") }) +test_that("SerDe of multi-element primitive vectors inside R data.frame", { + # vector of integers embedded in R data.frame + indices <- 1L:3L + myDf <- data.frame(indices) + myDf$data <- list(rep(0L, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(0L, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "integer") + + # vector of numeric embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep(0, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(0, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "numeric") + + # vector of logical embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep(TRUE, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(TRUE, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "logical") + + # vector of character embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep("abc", 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep("abc", 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "character") +}) + test_that("SerDe of list of primitive types", { x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index ad47717ddc12..a46c47dccd02 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -124,7 +124,7 @@ test_that("spark.logit", { # Petal.Width 0.42122607 # nolint end - # Test multinomial logistic regression againt three classes + # Test multinomial logistic regression against three classes df <- suppressWarnings(createDataFrame(iris)) model <- spark.logit(df, Species ~ ., regParam = 0.5) summary <- summary(model) @@ -196,7 +196,7 @@ test_that("spark.logit", { # # nolint end - # Test multinomial logistic regression againt two classes + # Test multinomial logistic regression against two classes df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] model <- spark.logit(training, Species ~ ., regParam = 0.5, family = "multinomial") @@ -208,7 +208,7 @@ test_that("spark.logit", { expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) - # Test binomial logistic regression againt two classes + # Test binomial logistic regression against two classes model <- spark.logit(training, Species ~ ., regParam = 0.5) summary <- summary(model) coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04) @@ -239,7 +239,7 @@ test_that("spark.logit", { prediction2 <- 
collect(select(predict(model2, df2), "prediction")) expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0")) - # Test binomial logistic regression againt two classes with upperBoundsOnCoefficients + # Test binomial logistic regression against two classes with upperBoundsOnCoefficients # and upperBoundsOnIntercepts u <- matrix(c(1.0, 0.0, 1.0, 0.0), nrow = 1, ncol = 4) model <- spark.logit(training, Species ~ ., upperBoundsOnCoefficients = u, @@ -252,7 +252,7 @@ test_that("spark.logit", { expect_error(spark.logit(training, Species ~ ., upperBoundsOnCoefficients = as.array(c(1, 2)), upperBoundsOnIntercepts = 1.0)) - # Test binomial logistic regression againt two classes with lowerBoundsOnCoefficients + # Test binomial logistic regression against two classes with lowerBoundsOnCoefficients # and lowerBoundsOnIntercepts l <- matrix(c(0.0, -1.0, 0.0, -1.0), nrow = 1, ncol = 4) model <- spark.logit(training, Species ~ ., lowerBoundsOnCoefficients = l, diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 5197838eaac6..43725e0ebd3b 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -67,6 +67,8 @@ sparkSession <- if (windows_with_hadoop()) { sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) } sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) +# materialize the catalog implementation +listTables() mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", @@ -1477,17 +1479,36 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains() and sort_array() + # Test array_contains(), array_max(), array_min(), array_position(), element_at() + # and sort_array() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_max(df[[1]])))[[1]] + expect_equal(result, c(3, 6)) + + result <- collect(select(df, array_min(df[[1]])))[[1]] + expect_equal(result, c(1, 4)) + + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 0)) + + result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 6)) + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) - # Test map_keys() and map_values() + # Test flattern + df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), + list(list(list(5L, 6L), list(7L, 8L))))) + result <- collect(select(df, flatten(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) + + # Test map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) @@ -1495,6 +1516,9 @@ test_that("column functions", { result <- collect(select(df, map_values(df$map)))[[1]] expect_equal(result, list(list(1, 2))) + result <- collect(select(df, element_at(df$map, "y")))[[1]] + expect_equal(result, 2) + # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) @@ -1649,6 +1673,7 @@ 
test_that("string operators", { expect_false(first(select(df, startsWith(df$name, "m")))[[1]]) expect_true(first(select(df, endsWith(df$name, "el")))[[1]]) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(first(select(df, substr(df$name, 4, 6)))[[1]], "hae") if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { expect_true(startsWith("Hello World", "Hello")) expect_false(endsWith("Hello World", "a")) @@ -2185,8 +2210,8 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { expect_equal(count(where(join(df, df2), df$name == df2$name)), 3) # cartesian join expect_error(tryCatch(count(join(df, df2)), error = function(e) { stop(e) }), - paste0(".*(org.apache.spark.sql.AnalysisException: Detected cartesian product for", - " INNER join between logical plans).*")) + paste0(".*(org.apache.spark.sql.AnalysisException: Detected implicit cartesian", + " product for INNER join between logical plans).*")) joined <- crossJoin(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) @@ -3101,6 +3126,51 @@ test_that("repartition by columns on DataFrame", { }) }) +test_that("repartitionByRange on a DataFrame", { + # The tasks here launch R workers with shuffles. So, we decrease the number of shuffle + # partitions to reduce the number of the tasks to speed up the test. This is particularly + # slow on Windows because the R workers are unable to be forked. See also SPARK-21693. + conf <- callJMethod(sparkSession, "conf") + shufflepartitionsvalue <- callJMethod(conf, "get", "spark.sql.shuffle.partitions") + callJMethod(conf, "set", "spark.sql.shuffle.partitions", "5") + tryCatch({ + df <- createDataFrame(mtcars) + expect_error(repartitionByRange(df, "haha", df$mpg), + "numPartitions and col must be numeric and Column.*") + expect_error(repartitionByRange(df), + ".*specify a column.*or the number of partitions with a column.*") + expect_error(repartitionByRange(df, col = "haha"), + "col must be Column; however, got.*") + expect_error(repartitionByRange(df, 3), + "At least one partition-by column must be specified.") + + # The order of rows should be different with a normal repartition. + actual <- repartitionByRange(df, 3, df$mpg) + expect_equal(getNumPartitions(actual), 3) + expect_false(identical(collect(actual), collect(repartition(df, 3, df$mpg)))) + + actual <- repartitionByRange(df, col = df$mpg) + expect_false(identical(collect(actual), collect(repartition(df, col = df$mpg)))) + + # They should have same data. 
+ actual <- collect(repartitionByRange(df, 3, df$mpg)) + actual <- actual[order(actual$mpg), ] + expected <- collect(repartition(df, 3, df$mpg)) + expected <- expected[order(expected$mpg), ] + expect_true(all(actual == expected)) + + actual <- collect(repartitionByRange(df, col = df$mpg)) + actual <- actual[order(actual$mpg), ] + expected <- collect(repartition(df, col = df$mpg)) + expected <- expected[order(expected$mpg), ] + expect_true(all(actual == expected)) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.shuffle.partitions", shufflepartitionsvalue) + }) +}) + test_that("coalesce, repartition, numPartitions", { df <- as.DataFrame(cars, numPartitions = 5) expect_equal(getNumPartitions(df), 5) diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index 54f40bbd5f51..bfb1a046490e 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -172,6 +172,113 @@ test_that("Terminated by error", { stopQuery(q) }) +test_that("PartitionBy", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + checkpointPath <- tempfile(pattern = "sparkr-test", fileext = ".checkpoint") + textPath <- tempfile(pattern = "sparkr-test", fileext = ".text") + df <- read.df(jsonPath, "json", stringSchema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + + expect_error(write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = c(1, 2)), + "All partitionBy column names should be characters") + + q <- write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = "name") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + dirs <- list.files(textPath) + expect_equal(length(dirs[substring(dirs, 1, nchar("name=")) == "name="]), 3) + + unlink(checkpointPath) + unlink(textPath) + unlink(parquetPath) +}) + +test_that("Watermark", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + t <- Sys.time() + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + df <- withColumn(df, "eventTime", cast(df$value, "timestamp")) + df <- withWatermark(df, "eventTime", "10 seconds") + counts <- count(group_by(df, "eventTime")) + q <- write.stream(counts, "memory", queryName = "times", outputMode = "append") + + # first events + df <- as.DataFrame(lapply(list(t + 1, t, t + 2), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # advance watermark to 15 + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # old events, should be dropped + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # evict events less than previous watermark + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + times <- collect(sql("SELECT * FROM times")) + # looks 
like write timing can affect the first bucket; but it should be t + expect_equal(times[order(times$eventTime), ][1, 2], 2) + + stopQuery(q) + unlink(parquetPath) +}) + +test_that("Trigger", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "", trigger.once = ""), "Multiple triggers not allowed.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = ""), + "Value for trigger.processingTime must be a non-empty string.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "invalid"), "illegal argument") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = ""), "Value for trigger.once must be TRUE.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = FALSE), "Value for trigger.once must be TRUE.") + + q <- write.stream(df, "memory", queryName = "times", outputMode = "append", trigger.once = TRUE) + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + expect_equal(nrow(collect(sql("SELECT * FROM times"))), 1) + + stopQuery(q) + unlink(parquetPath) +}) + unlink(jsonPath) unlink(jsonPathNa) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 2e662424b25f..d4713de7806a 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -46,7 +46,7 @@ Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt, sep = " ")) ## Overview -SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](https://spark.apache.org/mllib/). ## Getting Started @@ -132,7 +132,7 @@ sparkR.session.stop() Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. -After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. 
If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). +After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](https://spark.apache.org/downloads.html). ```{r, eval=FALSE} install.spark() @@ -147,7 +147,7 @@ sparkR.session(sparkHome = "/HOME/spark") ### Spark Session {#SetupSparkSession} -In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](http://spark.apache.org/docs/latest/api/R/sparkR.session.html). +In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](https://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](https://spark.apache.org/docs/latest/api/R/sparkR.session.html). In particular, the following Spark driver properties can be set in `sparkConfig`. @@ -169,7 +169,7 @@ sparkR.session(spark.sql.warehouse.dir = spark_warehouse_path) #### Cluster Mode -SparkR can connect to remote Spark clusters. [Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. +SparkR can connect to remote Spark clusters. [Cluster Mode Overview](https://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. When connecting SparkR to a remote Spark cluster, make sure that the Spark version and Hadoop version on the machine match the corresponding versions on the cluster. Current SparkR package is compatible with ```{r, echo=FALSE, tidy = TRUE} @@ -177,7 +177,7 @@ paste("Spark", packageVersion("SparkR")) ``` It should be used both on the local computer and on the remote cluster. -To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](http://spark.apache.org/docs/latest/submitting-applications.html#master-urls). +To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](https://spark.apache.org/docs/latest/submitting-applications.html#master-urls). For example, to connect to a local standalone Spark master, we can call ```{r, eval=FALSE} @@ -317,7 +317,7 @@ A common flow of grouping and aggregation is 2. Feed the `GroupedData` object to `agg` or `summarize` functions, with some provided aggregation functions to compute a number within each group. -A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for `mean`](http://spark.apache.org/docs/latest/api/R/mean.html) and other `agg_funcs` linked there. 
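In SparkR syntax, the two-step grouping-and-aggregation flow described here looks like the following; a sketch using `mtcars`, with the aggregate column names chosen for illustration:

```r
df <- createDataFrame(mtcars)
head(agg(groupBy(df, df$cyl),
         count = n(df$cyl),
         avg_hp = avg(df$hp)))
```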
+A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for aggregate functions](https://spark.apache.org/docs/latest/api/R/column_aggregate_functions.html) linked there. For example we can compute a histogram of the number of cylinders in the `mtcars` dataset as shown below. @@ -935,7 +935,7 @@ perplexity #### Alternating Least Squares -`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). +`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](https://dl.acm.org/citation.cfm?id=1608614). There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. @@ -1042,7 +1042,7 @@ unlink(modelPath) ## Structured Streaming -SparkR supports the Structured Streaming API (experimental). +SparkR supports the Structured Streaming API. You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. @@ -1171,11 +1171,11 @@ env | map ## References -* [Spark Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) +* [Spark Cluster Mode Overview](https://spark.apache.org/docs/latest/cluster-overview.html) -* [Submitting Spark Applications](http://spark.apache.org/docs/latest/submitting-applications.html) +* [Submitting Spark Applications](https://spark.apache.org/docs/latest/submitting-applications.html) -* [Machine Learning Library Guide (MLlib)](http://spark.apache.org/docs/latest/ml-guide.html) +* [Machine Learning Library Guide (MLlib)](https://spark.apache.org/docs/latest/ml-guide.html) * [SparkR: Scaling R Programs with Spark](https://people.csail.mit.edu/matei/papers/2016/sigmod_sparkr.pdf), Shivaram Venkataraman, Zongheng Yang, Davies Liu, Eric Liang, Hossein Falaki, Xiangrui Meng, Reynold Xin, Ali Ghodsi, Michael Franklin, Ion Stoica, and Matei Zaharia. SIGMOD 2016. June 2016. diff --git a/assembly/pom.xml b/assembly/pom.xml index b3b4239771bc..9608c96fd536 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml @@ -254,6 +254,14 @@ spark-hadoop-cloud_${scala.binary.version} ${project.version} + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh new file mode 100755 index 000000000000..f090240065bf --- /dev/null +++ b/bin/docker-image-tool.sh @@ -0,0 +1,151 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This script builds and pushes docker images when run from a release of Spark +# with Kubernetes support. + +function error { + echo "$@" 1>&2 + exit 1 +} + +if [ -z "${SPARK_HOME}" ]; then + SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi +. "${SPARK_HOME}/bin/load-spark-env.sh" + +function image_ref { + local image="$1" + local add_repo="${2:-1}" + if [ $add_repo = 1 ] && [ -n "$REPO" ]; then + image="$REPO/$image" + fi + if [ -n "$TAG" ]; then + image="$image:$TAG" + fi + echo "$image" +} + +function build { + local BUILD_ARGS + local IMG_PATH + + if [ ! -f "$SPARK_HOME/RELEASE" ]; then + # Set image build arguments accordingly if this is a source repo and not a distribution archive. + IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles + BUILD_ARGS=( + --build-arg + img_path=$IMG_PATH + --build-arg + spark_jars=assembly/target/scala-$SPARK_SCALA_VERSION/jars + ) + else + # Not passed as an argument to docker, but used to validate the Spark directory. + IMG_PATH="kubernetes/dockerfiles" + BUILD_ARGS=() + fi + + if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." + fi + + local DOCKERFILE=${DOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + + docker build "${BUILD_ARGS[@]}" \ + -t $(image_ref spark) \ + -f "$DOCKERFILE" . +} + +function push { + docker push "$(image_ref spark)" +} + +function usage { + cat </dev/null; then + error "Cannot find minikube." 
+ fi + eval $(minikube docker-env) + ;; + esac +done + +case "${@: -1}" in + build) + build + ;; + push) + if [ -z "$REPO" ]; then + usage + exit 1 + fi + push + ;; + *) + usage + exit 1 + ;; +esac diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index cf93d41cd77c..8c148359c302 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 4f9e10ca2006..0e491efac918 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -83,6 +83,7 @@ public LevelDB(File path, KVStoreSerializer serializer) throws Exception { if (versionData != null) { long version = serializer.deserializeLong(versionData); if (version != STORE_VERSION) { + close(); throw new UnsupportedStoreVersionException(); } } else { diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java index 232ee41dd0b1..f4d359234cb9 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java @@ -493,7 +493,7 @@ byte[] toKey(Object value, byte prefix) { byte[] key = new byte[bytes * 2 + 2]; long longValue = ((Number) value).longValue(); key[0] = prefix; - key[1] = longValue > 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; + key[1] = longValue >= 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; for (int i = 0; i < key.length - 2; i++) { int masked = (int) ((longValue >>> (4 * i)) & 0xF); diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java index 9a81f86812cd..1e062437d180 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java @@ -73,7 +73,9 @@ default BaseComparator reverse() { private static final BaseComparator NATURAL_ORDER = (t1, t2) -> t1.key.compareTo(t2.key); private static final BaseComparator REF_INDEX_ORDER = (t1, t2) -> t1.id.compareTo(t2.id); private static final BaseComparator COPY_INDEX_ORDER = (t1, t2) -> t1.name.compareTo(t2.name); - private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> t1.num - t2.num; + private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> { + return Integer.valueOf(t1.num).compareTo(t2.num); + }; private static final BaseComparator CHILD_INDEX_ORDER = (t1, t2) -> t1.child.compareTo(t2.child); /** @@ -112,7 +114,8 @@ public void setup() throws Exception { t.key = "key" + i; t.id = "id" + i; t.name = "name" + RND.nextInt(MAX_ENTRIES); - t.num = RND.nextInt(MAX_ENTRIES); + // Force one item to have an integer value of zero to test the fix for SPARK-23103. + t.num = (i != 0) ? 
(int) RND.nextLong() : 0; t.child = "child" + (i % MIN_ENTRIES); allEntries.add(t); } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index 2b07d249d202..b8123ac81d29 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.List; import java.util.NoSuchElementException; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import org.apache.commons.io.FileUtils; import org.iq80.leveldb.DBIterator; @@ -74,11 +76,7 @@ public void testReopenAndVersionCheckDb() throws Exception { @Test public void testObjectWriteReadDelete() throws Exception { - CustomType1 t = new CustomType1(); - t.key = "key"; - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); try { db.read(CustomType1.class, t.key); @@ -106,17 +104,9 @@ public void testObjectWriteReadDelete() throws Exception { @Test public void testMultipleObjectWriteReadDelete() throws Exception { - CustomType1 t1 = new CustomType1(); - t1.key = "key1"; - t1.id = "id"; - t1.name = "name1"; - t1.child = "child1"; - - CustomType1 t2 = new CustomType1(); - t2.key = "key2"; - t2.id = "id"; - t2.name = "name2"; - t2.child = "child2"; + CustomType1 t1 = createCustomType1(1); + CustomType1 t2 = createCustomType1(2); + t2.id = t1.id; db.write(t1); db.write(t2); @@ -142,11 +132,7 @@ public void testMultipleObjectWriteReadDelete() throws Exception { @Test public void testMultipleTypesWriteReadDelete() throws Exception { - CustomType1 t1 = new CustomType1(); - t1.key = "1"; - t1.id = "id"; - t1.name = "name1"; - t1.child = "child1"; + CustomType1 t1 = createCustomType1(1); IntKeyType t2 = new IntKeyType(); t2.key = 2; @@ -188,10 +174,7 @@ public void testMultipleTypesWriteReadDelete() throws Exception { public void testMetadata() throws Exception { assertNull(db.getMetadata(CustomType1.class)); - CustomType1 t = new CustomType1(); - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); db.setMetadata(t); assertEquals(t, db.getMetadata(CustomType1.class)); @@ -202,11 +185,7 @@ public void testMetadata() throws Exception { @Test public void testUpdate() throws Exception { - CustomType1 t = new CustomType1(); - t.key = "key"; - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); db.write(t); @@ -222,13 +201,7 @@ public void testUpdate() throws Exception { @Test public void testSkip() throws Exception { for (int i = 0; i < 10; i++) { - CustomType1 t = new CustomType1(); - t.key = "key" + i; - t.id = "id" + i; - t.name = "name" + i; - t.child = "child" + i; - - db.write(t); + db.write(createCustomType1(i)); } KVStoreIterator it = db.view(CustomType1.class).closeableIterator(); @@ -240,6 +213,36 @@ public void testSkip() throws Exception { assertFalse(it.hasNext()); } + @Test + public void testNegativeIndexValues() throws Exception { + List expected = Arrays.asList(-100, -50, 0, 50, 100); + + expected.stream().forEach(i -> { + try { + db.write(createCustomType1(i)); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + List results = StreamSupport + .stream(db.view(CustomType1.class).index("int").spliterator(), false) + .map(e -> e.num) + .collect(Collectors.toList()); + + assertEquals(expected, results); + } + + 
private CustomType1 createCustomType1(int i) { + CustomType1 t = new CustomType1(); + t.key = "key" + i; + t.id = "id" + i; + t.name = "name" + i; + t.num = i; + t.child = "child" + i; + return t; + } + private int countKeys(Class type) throws Exception { byte[] prefix = db.getTypeInfo(type).keyPrefix(); int count = 0; diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 18cbdadd224a..8ca7733507f1 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index c20fab83c346..8b8f9892847c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -24,6 +24,7 @@ import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; import com.google.common.base.Objects; import com.google.common.io.ByteStreams; @@ -132,7 +133,7 @@ public Object convertToNetty() throws IOException { if (conf.lazyFileDescriptor()) { return new DefaultFileRegion(file, offset, length); } else { - FileChannel fileChannel = new FileInputStream(file).getChannel(); + FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); return new DefaultFileRegion(fileChannel, offset, length); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 0719fa7647bc..60f51125c07f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -32,6 +32,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import org.apache.commons.lang3.SystemUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -98,6 +99,7 @@ private void init(String hostToBind, int portToBind) { .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) .option(ChannelOption.ALLOCATOR, allocator) + .option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS) .childOption(ChannelOption.ALLOCATOR, allocator); this.metrics = new NettyMemoryMetrics( diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 9968480ab765..05335df61a66 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 9cac7d00cc6b..0bc571874f07 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -171,7 +171,9 @@ private class DownloadCallback implements StreamCallback { @Override 
public void onData(String streamId, ByteBuffer buf) throws IOException { - channel.write(buf); + while (buf.hasRemaining()) { + channel.write(buf); + } } @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index eacf485344b7..386738ece51a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -19,10 +19,10 @@ import java.io.DataInputStream; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.LongBuffer; +import java.nio.file.Files; /** * Keeps the index information for a particular map output @@ -39,7 +39,7 @@ public ShuffleIndexInformation(File indexFile) throws IOException { offsets = buffer.asLongBuffer(); DataInputStream dis = null; try { - dis = new DataInputStream(new FileInputStream(indexFile)); + dis = new DataInputStream(Files.newInputStream(indexFile.toPath())); dis.readFully(buffer.array()); } finally { if (dis != null) { diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index ec2db6e5bb88..564e6583c909 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 2d59c71cc375..2f04abe8c7e8 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java index a61ce4fb7241..e83b331391e3 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java @@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i } public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + // This is not compatible with original and another implementations. + // But remain it for backward compatibility for the components existing before 2.3. assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; int h1 = hashBytesByInt(base, offset, lengthAligned, seed); @@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + // This is compatible with original and another implementations. + // Use this method for new components after Spark 2.3. 
+ assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int k1 = 0; + for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { + k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + } + h1 ^= mixK1(k1); + return fmix(h1, lengthInBytes); + } + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; diff --git a/common/tags/pom.xml b/common/tags/pom.xml index f7e586ee777e..ba127408e1c5 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index a3772a262008..152785473039 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 73577437ac50..62b75ae8aa01 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -38,12 +39,21 @@ public static int hashLong(long input) { return (int) ((input >>> 32) ^ input); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + public static int hashUnsafeBytesBlock(MemoryBlock mb) { + long lengthInBytes = mb.size(); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int result = 0; - for (int i = 0; i < lengthInBytes; i++) { - result = (result * 31) + (int) Platform.getByte(base, offset + i); + for (long i = 0; i < lengthInBytes; i++) { + result = (result * 31) + (int) mb.getByte(i); } return result; } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); + } + + public static int hashUTF8String(UTF8String str) { + return hashUnsafeBytesBlock(str.getMemoryBlock()); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index aca6fca00c48..54dcadf3a775 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -187,7 +187,7 @@ public static void setMemory(long address, byte value, long size) { } public static void copyMemory( - Object src, long srcOffset, Object dst, long dstOffset, long length) { + Object src, long srcOffset, Object dst, long dstOffset, long length) { // Check if dstOffset is before or after srcOffset to determine if we should copy // forward or backwards. This is necessary in case src and dst overlap. 
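Illustrative sketch (not part of the patch): the overlap handling described in the Platform.copyMemory comment above gives it memmove-like semantics, so an in-place shift within a single array is safe. Class and variable names are hypothetical.

    import java.util.Arrays;

    import org.apache.spark.unsafe.Platform;

    public class OverlapCopyExample {
      public static void main(String[] args) {
        byte[] buf = {1, 2, 3, 4, 5, 0};
        // Copy buf[0..4] onto buf[1..5]. Source and destination overlap; because the
        // destination offset is after the source offset, the copy runs backwards and
        // no byte is clobbered before it has been read.
        Platform.copyMemory(buf, Platform.BYTE_ARRAY_OFFSET,
                            buf, Platform.BYTE_ARRAY_OFFSET + 1, 5);
        System.out.println(Arrays.toString(buf));   // [1, 1, 2, 3, 4, 5]
      }
    }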
if (dstOffset < srcOffset) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index a6b1f7a16d60..ef0f78d95d1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.array; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; public class ByteArrayMethods { @@ -32,7 +33,11 @@ public static long nextPowerOf2(long num) { } public static int roundNumberOfBytesToNearestWord(int numBytes) { - int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + return (int)roundNumberOfBytesToNearestWord((long)numBytes); + } + + public static long roundNumberOfBytesToNearestWord(long numBytes) { + long remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` if (remainder == 0) { return numBytes; } else { @@ -48,15 +53,25 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; private static final boolean unaligned = Platform.unaligned(); + /** + * MemoryBlock equality check for MemoryBlocks. + * @return true if the arrays are equal, false otherwise + */ + public static boolean arrayEqualsBlock( + MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, long length) { + return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, + rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); + } + /** * Optimized byte array equality check for byte arrays. * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, long length) { int i = 0; - // check if stars align and we can get both offsets to be aligned + // check if starts align and we can get both offsets to be aligned if ((leftOffset % 8) == (rightOffset % 8)) { while ((leftOffset + i) % 8 != 0 && i < length) { if (Platform.getByte(leftBase, leftOffset + i) != diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 2cd39bd60c2a..b74d2de0691d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,6 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -33,16 +32,12 @@ public final class LongArray { private static final long WIDTH = 8; private final MemoryBlock memory; - private final Object baseObj; - private final long baseOffset; private final long length; public LongArray(MemoryBlock memory) { assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size >= Integer.MAX_VALUE elements"; this.memory = memory; - this.baseObj = memory.getBaseObject(); - this.baseOffset = memory.getBaseOffset(); this.length = memory.size() / WIDTH; } @@ -51,11 +46,11 @@ public MemoryBlock memoryBlock() { } public Object getBaseObject() { - return baseObj; + return memory.getBaseObject(); } public long getBaseOffset() { - return baseOffset; + return 
memory.getBaseOffset(); } /** @@ -69,8 +64,8 @@ public long size() { * Fill this all with 0L. */ public void zeroOut() { - for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { - Platform.putLong(baseObj, off, 0); + for (long off = 0; off < length * WIDTH; off += WIDTH) { + memory.putLong(off, 0); } } @@ -80,7 +75,7 @@ public void zeroOut() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - Platform.putLong(baseObj, baseOffset + index * WIDTH, value); + memory.putLong(index * WIDTH, value); } /** @@ -89,6 +84,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return Platform.getLong(baseObj, baseOffset + index * WIDTH); + return memory.getLong(index * WIDTH); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 5e7ee480cafd..aff6e93d647f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,7 +17,10 @@ package org.apache.spark.unsafe.hash; -import org.apache.spark.unsafe.Platform; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -49,33 +52,70 @@ public static int hashInt(int input, int seed) { } public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { - return hashUnsafeWords(base, offset, lengthInBytes, seed); + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); } - public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWordsBlock(MemoryBlock base, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; - int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + int h1 = hashBytesByIntBlock(base, seed); return fmix(h1, lengthInBytes); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { + // This is not compatible with original and another implementations. + // But remain it for backward compatibility for the components existing before 2.3. 
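Illustrative sketch (not part of the patch): the (base, offset, length) hashing overloads now just wrap their arguments in a MemoryBlock, so hashing through either entry point yields the same value. Names below are hypothetical; it assumes the common/unsafe classes introduced in this patch are on the classpath.

    import java.nio.charset.StandardCharsets;

    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.hash.Murmur3_x86_32;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    public class BlockHashExample {
      public static void main(String[] args) {
        byte[] bytes = "Spark".getBytes(StandardCharsets.UTF_8);
        MemoryBlock block =
            MemoryBlock.allocateFromObject(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length);
        int viaBlock  = Murmur3_x86_32.hashUnsafeBytesBlock(block, 42);
        int viaOffset = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 42);
        System.out.println(viaBlock == viaOffset);   // true
      }
    }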
+ int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); for (int i = lengthAligned; i < lengthInBytes; i++) { - int halfWord = Platform.getByte(base, offset + i); + int halfWord = base.getByte(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } return fmix(h1, lengthInBytes); } - private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUTF8String(UTF8String str, int seed) { + return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); + } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytes2Block(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { + // This is compatible with original and other implementations. + // Use this method for new components after Spark 2.3. + int lengthInBytes = Ints.checkedCast(base.size()); + assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); + int k1 = 0; + for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { + k1 ^= (base.getByte(i) & 0xFF) << shift; + } + h1 ^= mixK1(k1); + return fmix(h1, lengthInBytes); + } + + private static int hashBytesByIntBlock(MemoryBlock base, int seed) { + long lengthInBytes = base.size(); assert (lengthInBytes % 4 == 0); int h1 = seed; - for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = Platform.getInt(base, offset + i); + for (long i = 0; i < lengthInBytes; i += 4) { + int halfWord = base.getInt(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java new file mode 100644 index 000000000000..9f238632bc87 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a byte array on Java heap. 
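Illustrative sketch (not part of the patch): basic use of the new ByteArrayMemoryBlock. Offsets passed to the accessors are logical, 0-based positions in the block, and subBlock() shares the backing array rather than copying it. Class and variable names are made up.

    import java.nio.charset.StandardCharsets;

    import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    public class ByteArrayBlockExample {
      public static void main(String[] args) {
        byte[] backing = "hello world!".getBytes(StandardCharsets.UTF_8);
        MemoryBlock block = ByteArrayMemoryBlock.fromArray(backing);

        // Logical offset 0 is the first byte of the array; the object-header offset
        // of the backing byte[] is handled internally.
        block.putByte(0, (byte) 'H');

        // The sub-block is a view over the same byte[]; writes through it are
        // visible in the original block and in the array itself.
        MemoryBlock world = block.subBlock(6, 6);
        world.putByte(0, (byte) 'W');

        System.out.println(new String(backing, StandardCharsets.UTF_8));   // Hello World!
      }
    }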
+ */ +public final class ByteArrayMemoryBlock extends MemoryBlock { + + private final byte[] array; + + public ByteArrayMemoryBlock(byte[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= Platform.BYTE_ARRAY_OFFSET + obj.length) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length + Platform.BYTE_ARRAY_OFFSET); + } + + public ByteArrayMemoryBlock(long length) { + this(new byte[Ints.checkedCast(length)], Platform.BYTE_ARRAY_OFFSET, length); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new ByteArrayMemoryBlock(array, this.offset + offset, size); + } + + public byte[] getByteArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the byte array. + */ + public static ByteArrayMemoryBlock fromArray(final byte[] array) { + return new ByteArrayMemoryBlock(array, Platform.BYTE_ARRAY_OFFSET, array.length); + } + + @Override + public int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public byte getByte(long offset) { + return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; + } + + @Override + public void putByte(long offset, byte value) { + array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; + } + + @Override + public short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index cc9cc429643a..36caf80888cd 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -23,16 +23,13 @@ import java.util.LinkedList; import java.util.Map; -import org.apache.spark.unsafe.Platform; - /** * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. 
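Illustrative sketch (not part of the patch): the pooling behaviour that the reworked HeapMemoryAllocator keeps. For blocks at or above the 1 MB pooling threshold, free() parks the backing long[] behind a weak reference and a later allocation of a compatible size reuses it (here the local variable keeps the array strongly reachable, so the weak reference cannot have been cleared in between). Names are hypothetical.

    import org.apache.spark.unsafe.memory.MemoryAllocator;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    public class HeapPoolingExample {
      public static void main(String[] args) {
        MemoryBlock first = MemoryAllocator.HEAP.allocate(1024 * 1024);
        Object firstArray = first.getBaseObject();
        MemoryAllocator.HEAP.free(first);

        // The pooled long[] is recycled, but it is wrapped in a fresh OnHeapMemoryBlock.
        MemoryBlock second = MemoryAllocator.HEAP.allocate(1024 * 1024);
        System.out.println(firstArray == second.getBaseObject());   // true
        MemoryAllocator.HEAP.free(second);
      }
    }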
*/ public class HeapMemoryAllocator implements MemoryAllocator { @GuardedBy("this") - private final Map>> bufferPoolsBySize = - new HashMap<>(); + private final Map>> bufferPoolsBySize = new HashMap<>(); private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; @@ -47,27 +44,31 @@ private boolean shouldPool(long size) { @Override public MemoryBlock allocate(long size) throws OutOfMemoryError { - if (shouldPool(size)) { + int numWords = (int) ((size + 7) / 8); + long alignedSize = numWords * 8L; + assert (alignedSize >= size); + if (shouldPool(alignedSize)) { synchronized (this) { - final LinkedList> pool = bufferPoolsBySize.get(size); + final LinkedList> pool = bufferPoolsBySize.get(alignedSize); if (pool != null) { while (!pool.isEmpty()) { - final WeakReference blockReference = pool.pop(); - final MemoryBlock memory = blockReference.get(); - if (memory != null) { - assert (memory.size() == size); + final WeakReference arrayReference = pool.pop(); + final long[] array = arrayReference.get(); + if (array != null) { + assert (array.length * 8L >= size); + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } return memory; } } - bufferPoolsBySize.remove(size); + bufferPoolsBySize.remove(alignedSize); } } } - long[] array = new long[(int) ((size + 7) / 8)]; - MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + long[] array = new long[numWords]; + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -76,18 +77,38 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { + assert(memory instanceof OnHeapMemoryBlock); + assert (memory.getBaseObject() != null) : + "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; + assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "page has already been freed"; + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + + "free()"; + final long size = memory.size(); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } - if (shouldPool(size)) { + + // Mark the page as freed (so we can detect double-frees). + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); + + // As an additional layer of defense against use-after-free bugs, we mutate the + // MemoryBlock to null out its reference to the long[] array. 
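Illustrative sketch (not part of the patch): what the extra bookkeeping in free() buys. After a block is freed its base object is nulled out and its page number marks it as freed, so use-after-free reads fail fast and a second free() trips an assertion when the JVM runs with -ea. Names are hypothetical.

    import org.apache.spark.unsafe.memory.MemoryAllocator;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    public class FreeSafetyExample {
      public static void main(String[] args) {
        MemoryBlock block = MemoryAllocator.HEAP.allocate(1024);
        MemoryAllocator.HEAP.free(block);

        System.out.println(block.getBaseObject() == null);   // true
        System.out.println(block.getPageNumber() == MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER);   // true
        // MemoryAllocator.HEAP.free(block);   // would throw AssertionError under -ea (double free)
      }
    }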
+ long[] array = ((OnHeapMemoryBlock)memory).getLongArray(); + memory.resetObjAndOffset(); + + long alignedSize = ((size + 7) / 8) * 8; + if (shouldPool(alignedSize)) { synchronized (this) { - LinkedList> pool = bufferPoolsBySize.get(size); + LinkedList> pool = bufferPoolsBySize.get(alignedSize); if (pool == null) { pool = new LinkedList<>(); - bufferPoolsBySize.put(size, pool); + bufferPoolsBySize.put(alignedSize, pool); } - pool.add(new WeakReference<>(memory)); + pool.add(new WeakReference<>(array)); } } else { // Do nothing diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java index 7b588681d979..38315fb97b46 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -38,7 +38,7 @@ public interface MemoryAllocator { void free(MemoryBlock memory); - MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + UnsafeMemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); - MemoryAllocator HEAP = new HeapMemoryAllocator(); + HeapMemoryAllocator HEAP = new HeapMemoryAllocator(); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index cd1d378bc147..ca7213bbf92d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -22,42 +22,186 @@ import org.apache.spark.unsafe.Platform; /** - * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + * A representation of a consecutive memory block in Spark. It defines the common interfaces + * for memory accessing and mutating. */ -public class MemoryBlock extends MemoryLocation { +public abstract class MemoryBlock { + /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ + public static final int NO_PAGE_NUMBER = -1; - private final long length; + /** + * Special `pageNumber` value for marking pages that have been freed in the TaskMemoryManager. + * We set `pageNumber` to this value in TaskMemoryManager.freePage() so that MemoryAllocator + * can detect if pages which were allocated by TaskMemoryManager have been freed in the TMM + * before being passed to MemoryAllocator.free() (it is an error to allocate a page in + * TaskMemoryManager and then directly free it in a MemoryAllocator without going through + * the TMM freePage() call). + */ + public static final int FREED_IN_TMM_PAGE_NUMBER = -2; + + /** + * Special `pageNumber` value for pages that have been freed by the MemoryAllocator. This allows + * us to detect double-frees. + */ + public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; + + @Nullable + protected Object obj; + + protected long offset; + + protected long length; /** * Optional page number; used when this MemoryBlock represents a page allocated by a - * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, - * which lives in a different package. + * TaskMemoryManager. This field can be updated using setPageNumber method so that + * this can be modified by the TaskMemoryManager, which lives in a different package. 
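Illustrative sketch (not part of the patch): one visible effect of retyping the MemoryAllocator constants as the concrete allocator classes is that callers of UNSAFE.allocate() get an OffHeapMemoryBlock without a cast; that reading of the change is an inference on my part, and the names below are hypothetical.

    import org.apache.spark.unsafe.memory.MemoryAllocator;
    import org.apache.spark.unsafe.memory.OffHeapMemoryBlock;

    public class TypedAllocatorExample {
      public static void main(String[] args) {
        // UnsafeMemoryAllocator.allocate() is declared to return OffHeapMemoryBlock,
        // and MemoryAllocator.UNSAFE is now declared with that concrete allocator type.
        OffHeapMemoryBlock block = MemoryAllocator.UNSAFE.allocate(64);
        block.putLong(0, 42L);
        System.out.println(block.getLong(0));   // 42
        MemoryAllocator.UNSAFE.free(block);
      }
    }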
*/ - public int pageNumber = -1; + private int pageNumber = NO_PAGE_NUMBER; - public MemoryBlock(@Nullable Object obj, long offset, long length) { - super(obj, offset); + protected MemoryBlock(@Nullable Object obj, long offset, long length) { + if (offset < 0 || length < 0) { + throw new IllegalArgumentException( + "Length " + length + " and offset " + offset + "must be non-negative"); + } + this.obj = obj; + this.offset = offset; this.length = length; } + protected MemoryBlock() { + this(null, 0, 0); + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } + + public void resetObjAndOffset() { + this.obj = null; + this.offset = 0; + } + /** * Returns the size of the memory block. */ - public long size() { + public final long size() { return length; } - /** - * Creates a memory block pointing to the memory used by the long array. - */ - public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + public final void setPageNumber(int pageNum) { + pageNumber = pageNum; + } + + public final int getPageNumber() { + return pageNumber; } /** * Fills the memory block with the specified byte value. */ - public void fill(byte value) { + public final void fill(byte value) { Platform.setMemory(obj, offset, length, value); } + + /** + * Instantiate MemoryBlock for given object type with new offset + */ + public static final MemoryBlock allocateFromObject(Object obj, long offset, long length) { + MemoryBlock mb = null; + if (obj instanceof byte[]) { + byte[] array = (byte[])obj; + mb = new ByteArrayMemoryBlock(array, offset, length); + } else if (obj instanceof long[]) { + long[] array = (long[])obj; + mb = new OnHeapMemoryBlock(array, offset, length); + } else if (obj == null) { + // we assume that to pass null pointer means off-heap + mb = new OffHeapMemoryBlock(offset, length); + } else { + throw new UnsupportedOperationException( + "Instantiate MemoryBlock for type " + obj.getClass() + " is not supported now"); + } + return mb; + } + + /** + * Just instantiate the sub-block with the same type of MemoryBlock with the new size and relative + * offset from the original offset. The data is not copied. + * If parameters are invalid, an exception is thrown. + */ + public abstract MemoryBlock subBlock(long offset, long size); + + protected void checkSubBlockRange(long offset, long size) { + if (offset < 0 || size < 0) { + throw new ArrayIndexOutOfBoundsException( + "Size " + size + " and offset " + offset + " must be non-negative"); + } + if (offset + size > length) { + throw new ArrayIndexOutOfBoundsException("The sum of size " + size + " and offset " + + offset + " should not be larger than the length " + length + " in the MemoryBlock"); + } + } + + /** + * getXXX/putXXX does not ensure guarantee behavior if the offset is invalid. e.g cause illegal + * memory access, throw an exception, or etc. + * getXXX/putXXX uses an index based on this.offset that includes the size of metadata such as + * JVM object header. The offset is 0-based and is expected as an logical offset in the memory + * block. 
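Illustrative sketch (not part of the patch): how the 0-based logical offsets described above interact with subBlock(). A sub-block re-bases the offsets without copying any data. Names are hypothetical.

    import org.apache.spark.unsafe.memory.MemoryBlock;
    import org.apache.spark.unsafe.memory.OnHeapMemoryBlock;

    public class LogicalOffsetExample {
      public static void main(String[] args) {
        MemoryBlock block = new OnHeapMemoryBlock(32);   // backed by a fresh long[4]

        // Logical offset 8 is the second word of the block, independent of the JVM
        // object-header offset of the backing long[].
        block.putLong(8, 123L);

        // Logical offset 0 of the sub-block is logical offset 8 of the parent.
        MemoryBlock tail = block.subBlock(8, 24);
        System.out.println(tail.getLong(0));   // 123
      }
    }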
+ */ + public abstract int getInt(long offset); + + public abstract void putInt(long offset, int value); + + public abstract boolean getBoolean(long offset); + + public abstract void putBoolean(long offset, boolean value); + + public abstract byte getByte(long offset); + + public abstract void putByte(long offset, byte value); + + public abstract short getShort(long offset); + + public abstract void putShort(long offset, short value); + + public abstract long getLong(long offset); + + public abstract void putLong(long offset, long value); + + public abstract float getFloat(long offset); + + public abstract void putFloat(long offset, float value); + + public abstract double getDouble(long offset); + + public abstract void putDouble(long offset, double value); + + public static final void copyMemory( + MemoryBlock src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + assert(srcOffset + length <= src.length && dstOffset + length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset() + srcOffset, + dst.getBaseObject(), dst.getBaseOffset() + dstOffset, length); + } + + public static final void copyMemory(MemoryBlock src, MemoryBlock dst, long length) { + assert(length <= src.length && length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset(), + dst.getBaseObject(), dst.getBaseOffset(), length); + } + + public final void copyFrom(Object src, long srcOffset, long dstOffset, long length) { + assert(length <= this.length - srcOffset); + Platform.copyMemory(src, srcOffset, obj, offset + dstOffset, length); + } + + public final void writeTo(long srcOffset, Object dst, long dstOffset, long length) { + assert(length <= this.length - srcOffset); + Platform.copyMemory(obj, offset + srcOffset, dst, dstOffset, length); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java new file mode 100644 index 000000000000..3431b08980eb --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; + +public class OffHeapMemoryBlock extends MemoryBlock { + public static final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); + + public OffHeapMemoryBlock(long address, long size) { + super(null, address, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OffHeapMemoryBlock(this.offset + offset, size); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(null, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(null, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(null, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(null, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return Platform.getByte(null, this.offset + offset); + } + + @Override + public final void putByte(long offset, byte value) { + Platform.putByte(null, this.offset + offset, value); + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(null, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(null, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(null, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(null, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(null, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(null, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(null, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(null, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java new file mode 100644 index 000000000000..ee42bc27c9c5 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a long array on Java heap. + */ +public final class OnHeapMemoryBlock extends MemoryBlock { + + private final long[] array; + + public OnHeapMemoryBlock(long[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= obj.length * 8L + Platform.LONG_ARRAY_OFFSET) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length * 8L + Platform.LONG_ARRAY_OFFSET); + } + + public OnHeapMemoryBlock(long size) { + this(new long[Ints.checkedCast((size + 7) / 8)], Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OnHeapMemoryBlock(array, this.offset + offset, size); + } + + public long[] getLongArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static OnHeapMemoryBlock fromArray(final long[] array) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + } + + public static OnHeapMemoryBlock fromArray(final long[] array, long size) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public byte getByte(long offset) { + return Platform.getByte(array, this.offset + offset); + } + + @Override + public void putByte(long offset, byte value) { + Platform.putByte(array, this.offset + offset, value); + } + + @Override + public short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 55bcdf1ed7b0..5310bdf2779a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -25,9 +25,9 @@ public class UnsafeMemoryAllocator implements MemoryAllocator { @Override - public MemoryBlock allocate(long size) throws OutOfMemoryError { + public OffHeapMemoryBlock allocate(long size) throws OutOfMemoryError { long address = Platform.allocateMemory(size); - MemoryBlock memory = new MemoryBlock(null, address, size); + OffHeapMemoryBlock memory = new OffHeapMemoryBlock(address, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -36,11 +36,25 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert (memory.obj == null) : - "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + assert(memory instanceof OffHeapMemoryBlock) : + "UnsafeMemoryAllocator can only free OffHeapMemoryBlock."; + if (memory == OffHeapMemoryBlock.NULL) return; + assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "page has already been freed"; + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; + if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } + Platform.freeMemory(memory.offset); + + // As an additional layer of defense against use-after-free bugs, we mutate the + // MemoryBlock to reset its pointer. + memory.resetObjAndOffset(); + // Mark the page as freed (so we can detect double-frees). + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index c03caf0076f6..ecd7c19f2c63 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,10 +17,12 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.Platform; - import java.util.Arrays; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + public final class ByteArray { public static final byte[] EMPTY_BYTE = new byte[0]; @@ -77,17 +79,17 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) { public static byte[] concat(byte[]... 
inputs) { // Compute the total length of the result - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].length; + totalLength += (long)inputs[i].length; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].length; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index b0d0c44823e6..e91fc4391425 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -29,10 +29,13 @@ import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import com.google.common.primitives.Ints; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import static org.apache.spark.unsafe.Platform.*; @@ -50,19 +53,51 @@ public final class UTF8String implements Comparable, Externalizable, // These are only updated by readExternal() or read() @Nonnull - private Object base; - private long offset; + private MemoryBlock base; + // While numBytes has the same value as base.size(), to keep as int avoids cast from long to int private int numBytes; - public Object getBaseObject() { return base; } - public long getBaseOffset() { return offset; } + public MemoryBlock getMemoryBlock() { return base; } + public Object getBaseObject() { return base.getBaseObject(); } + public long getBaseOffset() { return base.getBaseOffset(); } - private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, - 5, 5, 5, 5, - 6, 6}; + /** + * A char in UTF-8 encoding can take 1-4 bytes depending on the first byte which + * indicates the size of the char. See Unicode standard in page 126, Table 3-6: + * http://www.unicode.org/versions/Unicode10.0.0/UnicodeStandard-10.0.pdf + * + * Binary Hex Comments + * 0xxxxxxx 0x00..0x7F Only byte of a 1-byte character encoding + * 10xxxxxx 0x80..0xBF Continuation bytes (1-3 continuation bytes) + * 110xxxxx 0xC0..0xDF First byte of a 2-byte character encoding + * 1110xxxx 0xE0..0xEF First byte of a 3-byte character encoding + * 11110xxx 0xF0..0xF7 First byte of a 4-byte character encoding + * + * As a consequence of the well-formedness conditions specified in + * Table 3-7 (page 126), the following byte values are disallowed in UTF-8: + * C0–C1, F5–FF. 
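Illustrative sketch (not part of the patch): the same first-byte classification that the table above encodes, written out as ordinary branches and run against a few sample characters. The helper and class names are made up.

    import java.nio.charset.StandardCharsets;

    public class Utf8FirstByteExample {
      // Mirrors bytesOfCodePointInUTF8: 0 means "not a valid first byte"
      // (continuation bytes and the values disallowed by UTF-8 well-formedness).
      static int charWidth(byte b) {
        int v = b & 0xFF;
        if (v <= 0x7F) return 1;   // single-byte (ASCII)
        if (v < 0xC2)  return 0;   // 0x80..0xC1: continuation byte or disallowed
        if (v <= 0xDF) return 2;   // 2-byte sequence
        if (v <= 0xEF) return 3;   // 3-byte sequence
        if (v <= 0xF4) return 4;   // 4-byte sequence
        return 0;                  // 0xF5..0xFF: disallowed
      }

      public static void main(String[] args) {
        String[] samples = {"a", "\u00e9", "\u20ac", "\uD83D\uDE00"};   // a, e-acute, euro sign, emoji
        for (String s : samples) {
          byte[] utf8 = s.getBytes(StandardCharsets.UTF_8);
          System.out.println(s + " -> first byte 0x" + Integer.toHexString(utf8[0] & 0xFF)
              + ", " + charWidth(utf8[0]) + " byte(s)");
        }
      }
    }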
+ */ + private static byte[] bytesOfCodePointInUTF8 = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x00..0x0F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x10..0x1F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x20..0x2F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x30..0x3F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x40..0x4F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x50..0x5F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x60..0x6F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x70..0x7F + // Continuation bytes cannot appear as the first byte + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x80..0x8F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x90..0x9F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xA0..0xAF + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xB0..0xBF + 0, 0, // 0xC0..0xC1 - disallowed in UTF-8 + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xC2..0xCF + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xD0..0xDF + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // 0xE0..0xEF + 4, 4, 4, 4, 4, // 0xF0..0xF4 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 // 0xF5..0xFF - disallowed in UTF-8 + }; private static final boolean IS_LITTLE_ENDIAN = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; @@ -77,7 +112,8 @@ public final class UTF8String implements Comparable, Externalizable, */ public static UTF8String fromBytes(byte[] bytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET, bytes.length)); } else { return null; } @@ -90,19 +126,13 @@ public static UTF8String fromBytes(byte[] bytes) { */ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET + offset, numBytes)); } else { return null; } } - /** - * Creates an UTF8String from given address (base and offset) and length. - */ - public static UTF8String fromAddress(Object base, long offset, int numBytes) { - return new UTF8String(base, offset, numBytes); - } - /** * Creates an UTF8String from String. */ @@ -119,16 +149,13 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - protected UTF8String(Object base, long offset, int numBytes) { + public UTF8String(MemoryBlock base) { this.base = base; - this.offset = offset; - this.numBytes = numBytes; + this.numBytes = Ints.checkedCast(base.size()); } // for serialization - public UTF8String() { - this(null, 0, 0); - } + public UTF8String() {} /** * Writes the content of this string into a memory address, identified by an object and an offset. @@ -136,7 +163,7 @@ public UTF8String() { * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - Platform.copyMemory(base, offset, target, targetOffset, numBytes); + base.writeTo(0, target, targetOffset, numBytes); } public void writeTo(ByteBuffer buffer) { @@ -156,8 +183,9 @@ public void writeTo(ByteBuffer buffer) { */ @Nonnull public ByteBuffer getByteBuffer() { - if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { - final byte[] bytes = (byte[]) base; + long offset = base.getBaseOffset(); + if (base instanceof ByteArrayMemoryBlock && offset >= BYTE_ARRAY_OFFSET) { + final byte[] bytes = ((ByteArrayMemoryBlock) base).getByteArray(); // the offset includes an object header... 
this is only needed for unsafe copies final long arrayOffset = offset - BYTE_ARRAY_OFFSET; @@ -187,8 +215,9 @@ public void writeTo(OutputStream out) throws IOException { * @param b The first byte of a code point */ private static int numBytesForFirstByte(final byte b) { - final int offset = (b & 0xFF) - 192; - return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1; + final int offset = b & 0xFF; + byte numBytes = bytesOfCodePointInUTF8[offset]; + return (numBytes == 0) ? 1: numBytes; // Skip the first byte disallowed in UTF-8 } /** @@ -223,12 +252,12 @@ public long getPrefix() { long mask = 0; if (IS_LITTLE_ENDIAN) { if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) Platform.getInt(base, offset); + p = (long) base.getInt(0); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -237,12 +266,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) Platform.getInt(base, offset)) << 32; + p = ((long) base.getInt(0)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -257,12 +286,13 @@ public long getPrefix() { */ public byte[] getBytes() { // avoid copy if `base` is `byte[]` - if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] - && ((byte[]) base).length == numBytes) { - return (byte[]) base; + long offset = base.getBaseOffset(); + if (offset == BYTE_ARRAY_OFFSET && base instanceof ByteArrayMemoryBlock + && (((ByteArrayMemoryBlock) base).getByteArray()).length == numBytes) { + return ((ByteArrayMemoryBlock) base).getByteArray(); } else { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return bytes; } } @@ -292,7 +322,7 @@ public UTF8String substring(final int start, final int until) { if (i > j) { byte[] bytes = new byte[i - j]; - copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); + base.writeTo(j, bytes, BYTE_ARRAY_OFFSET, i - j); return fromBytes(bytes); } else { return EMPTY_UTF8; @@ -333,14 +363,14 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. 
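Illustrative sketch (not part of the patch): a few of the UTF8String operations touched in this file, all of which now read straight from the underlying MemoryBlock instead of a raw (Object, offset) pair. Names are hypothetical.

    import org.apache.spark.unsafe.types.UTF8String;

    public class UTF8StringExample {
      public static void main(String[] args) {
        UTF8String s = UTF8String.fromString("Apache Spark");

        System.out.println(s.startsWith(UTF8String.fromString("Apache")));   // true
        System.out.println(s.contains(UTF8String.fromString("Spark")));      // true
        System.out.println(s.substring(7, 12));                              // Spark
        System.out.println(s.indexOf(UTF8String.fromString("Spark"), 0));    // 7
      }
    }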
*/ private byte getByte(int i) { - return Platform.getByte(base, offset + i); + return base.getByte(i); } private boolean matchAt(final UTF8String s, int pos) { if (s.numBytes + pos > numBytes || pos < 0) { return false; } - return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, pos, s.base, 0, s.numBytes); } public boolean startsWith(final UTF8String prefix) { @@ -467,8 +497,7 @@ public int findInSet(UTF8String match) { for (int i = 0; i < numBytes; i++) { if (getByte(i) == (byte) ',') { if (i - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } lastComma = i; @@ -476,8 +505,7 @@ public int findInSet(UTF8String match) { } } if (numBytes - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } return 0; @@ -492,7 +520,7 @@ public int findInSet(UTF8String match) { private UTF8String copyUTF8String(int start, int end) { int len = end - start + 1; byte[] newBytes = new byte[len]; - copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + base.writeTo(start, newBytes, BYTE_ARRAY_OFFSET, len); return UTF8String.fromBytes(newBytes); } @@ -639,8 +667,7 @@ public UTF8String reverse() { int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - copyMemory(this.base, this.offset + i, result, - BYTE_ARRAY_OFFSET + result.length - i - len, len); + base.writeTo(i, result, BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -654,7 +681,7 @@ public UTF8String repeat(int times) { } byte[] newBytes = new byte[numBytes * times]; - copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -691,7 +718,7 @@ public int indexOf(UTF8String v, int start) { if (i + v.numBytes > numBytes) { return -1; } - if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, i, v.base, 0, v.numBytes)) { return c; } i += numBytesForFirstByte(getByte(i)); @@ -707,7 +734,7 @@ public int indexOf(UTF8String v, int start) { private int find(UTF8String str, int start) { assert (str.numBytes > 0); while (start <= numBytes - str.numBytes) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start += 1; @@ -721,7 +748,7 @@ private int find(UTF8String str, int start) { private int rfind(UTF8String str, int start) { assert (str.numBytes > 0); while (start >= 0) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start -= 1; @@ -754,7 +781,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { return EMPTY_UTF8; } byte[] bytes = new byte[idx]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, idx); return fromBytes(bytes); } else { @@ -774,7 +801,7 @@ public UTF8String 
subStringIndex(UTF8String delim, int count) { } int size = numBytes - delim.numBytes - idx; byte[] bytes = new byte[size]; - copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + base.writeTo(idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); return fromBytes(bytes); } } @@ -797,15 +824,15 @@ public UTF8String rpad(int len, UTF8String pad) { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); + base.writeTo(0, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -833,13 +860,13 @@ public UTF8String lpad(int len, UTF8String pad) { int offset = 0; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); + base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -850,22 +877,22 @@ public UTF8String lpad(int len, UTF8String pad) { */ public static UTF8String concat(UTF8String... inputs) { // Compute the total length of the result. - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].numBytes; + totalLength += (long)inputs[i].numBytes; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it. - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -904,8 +931,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -913,8 +940,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { j++; // Add separator if this is not the last input. 
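Illustrative sketch (not part of the patch): the null handling that the concat()/concatWs() loops above imply. concat() returns null as soon as any input is null, while concatWs() skips null inputs and only places the separator between the non-null ones. Names are hypothetical.

    import org.apache.spark.unsafe.types.UTF8String;

    public class ConcatExample {
      public static void main(String[] args) {
        UTF8String comma = UTF8String.fromString(", ");
        UTF8String a = UTF8String.fromString("a");
        UTF8String b = UTF8String.fromString("b");
        UTF8String c = UTF8String.fromString("c");

        System.out.println(UTF8String.concat(a, b, c));             // abc
        System.out.println(UTF8String.concat(a, null, c));          // null
        System.out.println(UTF8String.concatWs(comma, a, null, c)); // a, c
      }
    }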
if (j < numInputs) { - copyMemory( - separator.base, separator.offset, + separator.base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; @@ -1188,7 +1215,7 @@ public UTF8String clone() { public UTF8String copy() { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return fromBytes(bytes); } @@ -1196,11 +1223,10 @@ public UTF8String copy() { public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); int wordMax = (len / 8) * 8; - long roffset = other.offset; - Object rbase = other.base; + MemoryBlock rbase = other.base; for (int i = 0; i < wordMax; i += 8) { - long left = getLong(base, offset + i); - long right = getLong(rbase, roffset + i); + long left = base.getLong(i); + long right = rbase.getLong(i); if (left != right) { if (IS_LITTLE_ENDIAN) { return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); @@ -1211,7 +1237,7 @@ public int compareTo(@Nonnull final UTF8String other) { } for (int i = wordMax; i < len; i++) { // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. - int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); + int res = (getByte(i) & 0xFF) - (rbase.getByte(i) & 0xFF); if (res != 0) { return res; } @@ -1230,7 +1256,7 @@ public boolean equals(final Object other) { if (numBytes != o.numBytes) { return false; } - return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, 0, o.base, 0, numBytes); } else { return false; } @@ -1286,8 +1312,8 @@ public int levenshteinDistance(UTF8String other) { num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { cost = 1; } else { - cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, - s.offset + i_bytes, num_bytes_j)) ? 0 : 1; + cost = (ByteArrayMethods.arrayEqualsBlock(t.base, j_bytes, s.base, + i_bytes, num_bytes_j)) ? 
0 : 1; } d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); } @@ -1302,7 +1328,7 @@ public int levenshteinDistance(UTF8String other) { @Override public int hashCode() { - return Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, 42); + return Murmur3_x86_32.hashUnsafeBytesBlock(base,42); } /** @@ -1365,10 +1391,10 @@ public void writeExternal(ObjectOutput out) throws IOException { } public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - offset = BYTE_ARRAY_OFFSET; numBytes = in.readInt(); - base = new byte[numBytes]; - in.readFully((byte[]) base); + byte[] bytes = new byte[numBytes]; + in.readFully(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } @Override @@ -1380,10 +1406,10 @@ public void write(Kryo kryo, Output out) { @Override public void read(Kryo kryo, Input in) { - this.offset = BYTE_ARRAY_OFFSET; - this.numBytes = in.readInt(); - this.base = new byte[numBytes]; - in.read((byte[]) base); + numBytes = in.readInt(); + byte[] bytes = new byte[numBytes]; + in.read(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 4b141339ec81..583a148b3845 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.unsafe; +import org.apache.spark.unsafe.memory.HeapMemoryAllocator; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -62,6 +63,52 @@ public void overlappingCopyMemory() { } } + @Test + public void onHeapMemoryAllocatorPoolingReUsesLongArrays() { + MemoryBlock block1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object baseObject1 = block1.getBaseObject(); + MemoryAllocator.HEAP.free(block1); + MemoryBlock block2 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object baseObject2 = block2.getBaseObject(); + Assert.assertSame(baseObject1, baseObject2); + MemoryAllocator.HEAP.free(block2); + } + + @Test + public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { + MemoryBlock block = MemoryAllocator.HEAP.allocate(1024); + Assert.assertNotNull(block.getBaseObject()); + MemoryAllocator.HEAP.free(block); + Assert.assertNull(block.getBaseObject()); + Assert.assertEquals(0, block.getBaseOffset()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); + } + + @Test + public void freeingOffHeapMemoryBlockResetsOffset() { + MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024); + Assert.assertNull(block.getBaseObject()); + Assert.assertNotEquals(0, block.getBaseOffset()); + MemoryAllocator.UNSAFE.free(block); + Assert.assertNull(block.getBaseObject()); + Assert.assertEquals(0, block.getBaseOffset()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); + } + + @Test(expected = AssertionError.class) + public void onHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() { + MemoryBlock block = MemoryAllocator.HEAP.allocate(1024); + MemoryAllocator.HEAP.free(block); + MemoryAllocator.HEAP.free(block); + } + + @Test(expected = AssertionError.class) + public void offHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() { + MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024); + MemoryAllocator.UNSAFE.free(block); + MemoryAllocator.UNSAFE.free(block); + } + @Test public 
void memoryDebugFillEnabledInTest() { Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED); @@ -71,9 +118,11 @@ public void memoryDebugFillEnabledInTest() { MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object onheap1BaseObject = onheap1.getBaseObject(); + long onheap1BaseOffset = onheap1.getBaseOffset(); MemoryAllocator.HEAP.free(onheap1); Assert.assertEquals( - Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()), + Platform.getByte(onheap1BaseObject, onheap1BaseOffset), MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024); Assert.assertEquals( @@ -86,4 +135,26 @@ public void memoryDebugFillEnabledInTest() { MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); MemoryAllocator.UNSAFE.free(offheap); } + + @Test + public void heapMemoryReuse() { + MemoryAllocator heapMem = new HeapMemoryAllocator(); + // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`, + // allocate new memory every time. + MemoryBlock onheap1 = heapMem.allocate(513); + Object obj1 = onheap1.getBaseObject(); + heapMem.free(onheap1); + MemoryBlock onheap2 = heapMem.allocate(514); + Assert.assertNotEquals(obj1, onheap2.getBaseObject()); + + // The size is greater than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`, + // reuse the previous memory which has released. + MemoryBlock onheap3 = heapMem.allocate(1024 * 1024 + 1); + Assert.assertEquals(onheap3.size(), 1024 * 1024 + 1); + Object obj3 = onheap3.getBaseObject(); + heapMem.free(onheap3); + MemoryBlock onheap4 = heapMem.allocate(1024 * 1024 + 7); + Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); + Assert.assertEquals(obj3, onheap4.getBaseObject()); + } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index fb8e53b3348f..8c2e98c2bfc5 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -20,14 +20,13 @@ import org.junit.Assert; import org.junit.Test; -import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; public class LongArraySuite { @Test public void basicTest() { - long[] bytes = new long[2]; - LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + LongArray arr = new LongArray(new OnHeapMemoryBlock(16)); arr.set(0, 1L); arr.set(1, 2L); arr.set(1, 3L); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index e759cb33b3e6..d7ed005db189 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -22,6 +22,8 @@ import java.util.Random; import java.util.Set; +import scala.util.hashing.MurmurHash3$; + import org.apache.spark.unsafe.Platform; import org.junit.Assert; import org.junit.Test; @@ -51,6 +53,41 @@ public void testKnownLongInputs() { Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE)); } + // SPARK-23381 Check whether the hash of the byte array is the same as another implementations + @Test + public void testKnownBytesInputs() { + byte[] test = "test".getBytes(StandardCharsets.UTF_8); + 
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test, 0), + Murmur3_x86_32.hashUnsafeBytes2(test, Platform.BYTE_ARRAY_OFFSET, test.length, 0)); + byte[] test1 = "test1".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test1, 0), + Murmur3_x86_32.hashUnsafeBytes2(test1, Platform.BYTE_ARRAY_OFFSET, test1.length, 0)); + byte[] te = "te".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(te, 0), + Murmur3_x86_32.hashUnsafeBytes2(te, Platform.BYTE_ARRAY_OFFSET, te.length, 0)); + byte[] tes = "tes".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(tes, 0), + Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); + } + + @Test + public void testKnownWordsInputs() { + byte[] bytes = new byte[16]; + long offset = Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < 16; i++) { + bytes[i] = 0; + } + Assert.assertEquals(-300363099, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = -1; + } + Assert.assertEquals(-1210324667, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = (byte)i; + } + Assert.assertEquals(-634919701, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + } + @Test public void randomizedStressTest() { int size = 65536; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java new file mode 100644 index 000000000000..ef5ff8ee70ec --- /dev/null +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
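The testKnownBytesInputs cases above pin the Java Murmur3_x86_32 implementation to Scala's scala.util.hashing.MurmurHash3 for byte arrays (SPARK-23381). A minimal standalone check along the same lines is sketched below; it assumes the Spark unsafe module and the Scala library are on the classpath, and the class name and sample input are made up.

import java.nio.charset.StandardCharsets;

import scala.util.hashing.MurmurHash3$;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;

public class HashParitySketch {
  public static void main(String[] args) {
    // Any byte payload should hash identically through both implementations when the
    // same seed is used (seed 0 here, as in the test above).
    byte[] bytes = "spark".getBytes(StandardCharsets.UTF_8);
    int fromScala = MurmurHash3$.MODULE$.bytesHash(bytes, 0);
    int fromUnsafe =
        Murmur3_x86_32.hashUnsafeBytes2(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
    System.out.println(fromScala == fromUnsafe);  // expected: true
  }
}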
+ */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteOrder; + +import static org.hamcrest.core.StringContains.containsString; + +public class MemoryBlockSuite { + private static final boolean bigEndianPlatform = + ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + + private void check(MemoryBlock memory, Object obj, long offset, int length) { + memory.setPageNumber(1); + memory.fill((byte)-1); + memory.putBoolean(0, true); + memory.putByte(1, (byte)127); + memory.putShort(2, (short)257); + memory.putInt(4, 0x20000002); + memory.putLong(8, 0x1234567089ABCDEFL); + memory.putFloat(16, 1.0F); + memory.putLong(20, 0x1234567089ABCDEFL); + memory.putDouble(28, 2.0); + MemoryBlock.copyMemory(memory, 0L, memory, 36, 4); + int[] a = new int[2]; + a[0] = 0x12345678; + a[1] = 0x13579BDF; + memory.copyFrom(a, Platform.INT_ARRAY_OFFSET, 40, 8); + byte[] b = new byte[8]; + memory.writeTo(40, b, Platform.BYTE_ARRAY_OFFSET, 8); + + Assert.assertEquals(obj, memory.getBaseObject()); + Assert.assertEquals(offset, memory.getBaseOffset()); + Assert.assertEquals(length, memory.size()); + Assert.assertEquals(1, memory.getPageNumber()); + Assert.assertEquals(true, memory.getBoolean(0)); + Assert.assertEquals((byte)127, memory.getByte(1 )); + Assert.assertEquals((short)257, memory.getShort(2)); + Assert.assertEquals(0x20000002, memory.getInt(4)); + Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(8)); + Assert.assertEquals(1.0F, memory.getFloat(16), 0); + Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(20)); + Assert.assertEquals(2.0, memory.getDouble(28), 0); + Assert.assertEquals(true, memory.getBoolean(36)); + Assert.assertEquals((byte)127, memory.getByte(37 )); + Assert.assertEquals((short)257, memory.getShort(38)); + Assert.assertEquals(a[0], memory.getInt(40)); + Assert.assertEquals(a[1], memory.getInt(44)); + if (bigEndianPlatform) { + Assert.assertEquals(a[0], + ((int)b[0] & 0xff) << 24 | ((int)b[1] & 0xff) << 16 | + ((int)b[2] & 0xff) << 8 | ((int)b[3] & 0xff)); + Assert.assertEquals(a[1], + ((int)b[4] & 0xff) << 24 | ((int)b[5] & 0xff) << 16 | + ((int)b[6] & 0xff) << 8 | ((int)b[7] & 0xff)); + } else { + Assert.assertEquals(a[0], + ((int)b[3] & 0xff) << 24 | ((int)b[2] & 0xff) << 16 | + ((int)b[1] & 0xff) << 8 | ((int)b[0] & 0xff)); + Assert.assertEquals(a[1], + ((int)b[7] & 0xff) << 24 | ((int)b[6] & 0xff) << 16 | + ((int)b[5] & 0xff) << 8 | ((int)b[4] & 0xff)); + } + for (int i = 48; i < memory.size(); i++) { + Assert.assertEquals((byte) -1, memory.getByte(i)); + } + + assert(memory.subBlock(0, memory.size()) == memory); + + try { + memory.subBlock(-8, 8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("non-negative")); + } + + try { + memory.subBlock(0, -8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("non-negative")); + } + + try { + memory.subBlock(0, length + 8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + + try { + memory.subBlock(8, length - 4); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + + try { + memory.subBlock(length + 8, 4); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not 
be larger than")); + } + + memory.setPageNumber(MemoryBlock.NO_PAGE_NUMBER); + } + + @Test + public void testByteArrayMemoryBlock() { + byte[] obj = new byte[56]; + long offset = Platform.BYTE_ARRAY_OFFSET; + int length = obj.length; + + MemoryBlock memory = new ByteArrayMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + + memory = ByteArrayMemoryBlock.fromArray(obj); + check(memory, obj, offset, length); + + obj = new byte[112]; + memory = new ByteArrayMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + } + + @Test + public void testOnHeapMemoryBlock() { + long[] obj = new long[7]; + long offset = Platform.LONG_ARRAY_OFFSET; + int length = obj.length * 8; + + MemoryBlock memory = new OnHeapMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + + memory = OnHeapMemoryBlock.fromArray(obj); + check(memory, obj, offset, length); + + obj = new long[14]; + memory = new OnHeapMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + } + + @Test + public void testOffHeapArrayMemoryBlock() { + MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator(); + MemoryBlock memory = memoryAllocator.allocate(56); + Object obj = memory.getBaseObject(); + long offset = memory.getBaseOffset(); + int length = 56; + + check(memory, obj, offset, length); + memoryAllocator.free(memory); + + long address = Platform.allocateMemory(112); + memory = new OffHeapMemoryBlock(address, length); + obj = memory.getBaseObject(); + offset = memory.getBaseOffset(); + check(memory, obj, offset, length); + Platform.freeMemory(address); + } +} diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 9b303fa5bc6c..42dda3048070 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -25,7 +25,8 @@ import java.util.*; import com.google.common.collect.ImmutableMap; -import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; import static org.junit.Assert.*; @@ -51,15 +52,19 @@ private static void checkBasic(String str, int len) { assertTrue(s1.contains(s2)); assertTrue(s2.contains(s1)); - assertTrue(s1.startsWith(s1)); - assertTrue(s1.endsWith(s1)); + assertTrue(s1.startsWith(s2)); + assertTrue(s1.endsWith(s2)); } @Test public void basicTest() { checkBasic("", 0); - checkBasic("hello", 5); + checkBasic("¡", 1); // 2 bytes char + checkBasic("ку", 2); // 2 * 2 bytes chars + checkBasic("hello", 5); // 5 * 1 byte chars checkBasic("大 千 世 界", 7); + checkBasic("︽﹋%", 3); // 3 * 3 bytes chars + checkBasic("\uD83E\uDD19", 1); // 4 bytes char } @Test @@ -508,21 +513,6 @@ public void soundex() { assertEquals(fromString("世界千世").soundex(), fromString("世界千世")); } - @Test - public void writeToOutputStreamUnderflow() throws IOException { - // offset underflow is apparently supported? 
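The new MemoryBlockSuite above exercises the common surface of the MemoryBlock hierarchy: typed put/get accessors addressed by an offset relative to the block, copyFrom/writeTo against plain arrays, and subBlock views. A small sketch of that usage pattern, assuming the classes introduced by this patch (OnHeapMemoryBlock, ByteArrayMemoryBlock) are available; the class name and values are illustrative.

import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.OnHeapMemoryBlock;

public class MemoryBlockSketch {
  public static void main(String[] args) {
    // Offsets are relative to the block itself, not raw (baseObject, baseOffset) pairs.
    MemoryBlock onHeap = new OnHeapMemoryBlock(32);
    onHeap.putLong(0, 0x1234567089ABCDEFL);
    onHeap.putInt(8, 42);
    System.out.println(onHeap.getInt(8));       // 42

    // subBlock() gives a view over a sub-range; indices restart at 0 inside the view.
    MemoryBlock view = onHeap.subBlock(8, 8);
    System.out.println(view.getInt(0));         // 42

    // A byte[]-backed block wraps an existing array without copying.
    byte[] bytes = new byte[]{1, 2, 3, 4};
    MemoryBlock wrapped = ByteArrayMemoryBlock.fromArray(bytes);
    System.out.println(wrapped.getByte(3));     // 4
  }
}

This is the pattern that lets UTF8String and the sorters above drop their separate base/offset fields and pass a single MemoryBlock around.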
- final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); - - for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { - UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i) - .writeTo(outputStream); - final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); - assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); - outputStream.reset(); - } - } - @Test public void writeToOutputStreamSlice() throws IOException { final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); @@ -530,7 +520,7 @@ public void writeToOutputStreamSlice() throws IOException { for (int i = 0; i < test.length; ++i) { for (int j = 0; j < test.length - i; ++j) { - UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j) + new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(i, j)) .writeTo(outputStream); assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray()); @@ -561,7 +551,7 @@ public void writeToOutputStreamOverflow() throws IOException { for (final long offset : offsets) { try { - fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length) + new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(offset, test.length)) .writeTo(outputStream); throw new IllegalStateException(Long.toString(offset)); @@ -588,26 +578,25 @@ public void writeToOutputStream() throws IOException { } @Test - public void writeToOutputStreamIntArray() throws IOException { + public void writeToOutputStreamLongArray() throws IOException { // verify that writes work on objects that are not byte arrays - final ByteBuffer buffer = StandardCharsets.UTF_8.encode("大千世界"); + final ByteBuffer buffer = StandardCharsets.UTF_8.encode("3千大千世界"); buffer.position(0); buffer.order(ByteOrder.nativeOrder()); final int length = buffer.limit(); - assertEquals(12, length); + assertEquals(16, length); - final int ints = length / 4; - final int[] array = new int[ints]; + final int longs = length / 8; + final long[] array = new long[longs]; - for (int i = 0; i < ints; ++i) { - array[i] = buffer.getInt(); + for (int i = 0; i < longs; ++i) { + array[i] = buffer.getLong(); } final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - fromAddress(array, Platform.INT_ARRAY_OFFSET, length) - .writeTo(outputStream); - assertEquals("大千世界", outputStream.toString("UTF-8")); + new UTF8String(OnHeapMemoryBlock.fromArray(array)).writeTo(outputStream); + assertEquals("3千大千世界", outputStream.toString("UTF-8")); } @Test @@ -791,4 +780,21 @@ public void trimRightWithTrimString() { assertEquals(fromString("头"), fromString("头a???/").trimRight(fromString("数?/*&^%a"))); assertEquals(fromString("头"), fromString("头数b数数 [").trimRight(fromString(" []数b"))); } + + @Test + public void skipWrongFirstByte() { + int[] wrongFirstBytes = { + 0x80, 0x9F, 0xBF, // Skip Continuation bytes + 0xC0, 0xC2, // 0xC0..0xC1 - disallowed in UTF-8 + // 0xF5..0xFF - disallowed in UTF-8 + 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, + 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF + }; + byte[] c = new byte[1]; + + for (int i = 0; i < wrongFirstBytes.length; ++i) { + c[0] = (byte)wrongFirstBytes[i]; + assertEquals(fromBytes(c).numChars(), 1); + } + } } diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 62d4176d00f9..48004e812a8b 100644 --- 
a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -164,7 +164,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty def padding(origin: String, pad: String, length: Int, isLPad: Boolean): String = { if (length <= 0) return "" if (length <= origin.length) { - if (length <= 0) "" else origin.substring(0, length) + origin.substring(0, length) } else { if (pad.length == 0) return origin val toPad = length - origin.length diff --git a/core/pom.xml b/core/pom.xml index 0a5bd958fc9c..093a9869b6dd 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml @@ -95,6 +95,12 @@ org.apache.curator curator-recipes + + + org.apache.zookeeper + zookeeper + diff --git a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java index dc3e82647598..2b93385adf10 100644 --- a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java +++ b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java @@ -30,4 +30,8 @@ public interface SparkExecutorInfo extends Serializable { int port(); long cacheSize(); int numRunningTasks(); + long usedOnHeapStorageMemory(); + long usedOffHeapStorageMemory(); + long totalOnHeapStorageMemory(); + long totalOffHeapStorageMemory(); } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 3583856d8899..94c5c11b61a5 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,6 +118,18 @@ public final void onExecutorBlacklisted(SparkListenerExecutorBlacklisted executo onEvent(executorBlacklisted); } + @Override + public void onExecutorBlacklistedForStage( + SparkListenerExecutorBlacklistedForStage executorBlacklistedForStage) { + onEvent(executorBlacklistedForStage); + } + + @Override + public void onNodeBlacklistedForStage( + SparkListenerNodeBlacklistedForStage nodeBlacklistedForStage) { + onEvent(nodeBlacklistedForStage); + } + @Override public final void onExecutorUnblacklisted( SparkListenerExecutorUnblacklisted executorUnblacklisted) { diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 5b45d268ace8..0cced9e22295 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -27,6 +27,7 @@ import java.nio.ByteBuffer; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; @@ -78,9 +79,8 @@ public class ReadAheadInputStream extends InputStream { // whether there is a read ahead task running, private boolean isReading; - // If the remaining data size in the current buffer is below this threshold, - // we issue an async read from the underlying input stream. - private final int readAheadThresholdInBytes; + // whether there is a reader waiting for data. 
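The SparkFirehoseListener changes above add overrides for the new per-stage blacklisting events; every callback, including those, is funneled into a single onEvent() hook, so a subclass only needs one method to observe the full event stream. An illustrative subclass follows; the class name and the logging it does are made up.

import org.apache.spark.SparkFirehoseListener;
import org.apache.spark.scheduler.SparkListenerEvent;

public class EventCountingListener extends SparkFirehoseListener {
  private long count = 0;

  @Override
  public void onEvent(SparkListenerEvent event) {
    // Called for every listener event, including SparkListenerExecutorBlacklistedForStage
    // and SparkListenerNodeBlacklistedForStage added in this change.
    count++;
    System.out.println(count + ": " + event.getClass().getSimpleName());
  }
}

Such a listener would typically be registered through SparkContext.addSparkListener or the spark.extraListeners configuration.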
+ private AtomicBoolean isWaiting = new AtomicBoolean(false); private final InputStream underlyingInputStream; @@ -97,20 +97,13 @@ public class ReadAheadInputStream extends InputStream { * * @param inputStream The underlying input stream. * @param bufferSizeInBytes The buffer size. - * @param readAheadThresholdInBytes If the active buffer has less data than the read-ahead - * threshold, an async read is triggered. */ public ReadAheadInputStream( - InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) { + InputStream inputStream, int bufferSizeInBytes) { Preconditions.checkArgument(bufferSizeInBytes > 0, "bufferSizeInBytes should be greater than 0, but the value is " + bufferSizeInBytes); - Preconditions.checkArgument(readAheadThresholdInBytes > 0 && - readAheadThresholdInBytes < bufferSizeInBytes, - "readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes, " + - "but the value is " + readAheadThresholdInBytes); activeBuffer = ByteBuffer.allocate(bufferSizeInBytes); readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes); - this.readAheadThresholdInBytes = readAheadThresholdInBytes; this.underlyingInputStream = inputStream; activeBuffer.flip(); readAheadBuffer.flip(); @@ -166,12 +159,17 @@ public void run() { // in that case the reader waits for this async read to complete. // So there is no race condition in both the situations. int read = 0; + int off = 0, len = arr.length; Throwable exception = null; try { - while (true) { - read = underlyingInputStream.read(arr); - if (0 != read) break; - } + // try to fill the read ahead buffer. + // if a reader is waiting, possibly return early. + do { + read = underlyingInputStream.read(arr, off, len); + if (read <= 0) break; + off += read; + len -= read; + } while (len > 0 && !isWaiting.get()); } catch (Throwable ex) { exception = ex; if (ex instanceof Error) { @@ -181,13 +179,12 @@ public void run() { } } finally { stateChangeLock.lock(); + readAheadBuffer.limit(off); if (read < 0 || (exception instanceof EOFException)) { endOfStream = true; } else if (exception != null) { readAborted = true; readException = exception; - } else { - readAheadBuffer.limit(read); } readInProgress = false; signalAsyncReadComplete(); @@ -230,7 +227,10 @@ private void signalAsyncReadComplete() { private void waitForAsyncReadComplete() throws IOException { stateChangeLock.lock(); + isWaiting.set(true); try { + // There is only one reader, and one writer, so the writer should signal only once, + // but a while loop checking the wake up condition is still needed to avoid spurious wakeups. while (readInProgress) { asyncReadComplete.await(); } @@ -239,6 +239,7 @@ private void waitForAsyncReadComplete() throws IOException { iio.initCause(e); throw iio; } finally { + isWaiting.set(false); stateChangeLock.unlock(); } checkReadException(); @@ -246,8 +247,13 @@ private void waitForAsyncReadComplete() throws IOException { @Override public int read() throws IOException { - byte[] oneByteArray = oneByte.get(); - return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF; + if (activeBuffer.hasRemaining()) { + // short path - just get one byte. + return activeBuffer.get() & 0xFF; + } else { + byte[] oneByteArray = oneByte.get(); + return read(oneByteArray, 0, 1) == -1 ? 
-1 : oneByteArray[0] & 0xFF; + } } @Override @@ -258,54 +264,43 @@ public int read(byte[] b, int offset, int len) throws IOException { if (len == 0) { return 0; } - stateChangeLock.lock(); - try { - return readInternal(b, offset, len); - } finally { - stateChangeLock.unlock(); - } - } - /** - * flip the active and read ahead buffer - */ - private void swapBuffers() { - ByteBuffer temp = activeBuffer; - activeBuffer = readAheadBuffer; - readAheadBuffer = temp; - } - - /** - * Internal read function which should be called only from read() api. The assumption is that - * the stateChangeLock is already acquired in the caller before calling this function. - */ - private int readInternal(byte[] b, int offset, int len) throws IOException { - assert (stateChangeLock.isLocked()); if (!activeBuffer.hasRemaining()) { - waitForAsyncReadComplete(); - if (readAheadBuffer.hasRemaining()) { - swapBuffers(); - } else { - // The first read or activeBuffer is skipped. - readAsync(); + // No remaining in active buffer - lock and switch to write ahead buffer. + stateChangeLock.lock(); + try { waitForAsyncReadComplete(); - if (isEndOfStream()) { - return -1; + if (!readAheadBuffer.hasRemaining()) { + // The first read. + readAsync(); + waitForAsyncReadComplete(); + if (isEndOfStream()) { + return -1; + } } + // Swap the newly read read ahead buffer in place of empty active buffer. swapBuffers(); + // After swapping buffers, trigger another async read for read ahead buffer. + readAsync(); + } finally { + stateChangeLock.unlock(); } - } else { - checkReadException(); } len = Math.min(len, activeBuffer.remaining()); activeBuffer.get(b, offset, len); - if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) { - readAsync(); - } return len; } + /** + * flip the active and read ahead buffer + */ + private void swapBuffers() { + ByteBuffer temp = activeBuffer; + activeBuffer = readAheadBuffer; + readAheadBuffer = temp; + } + @Override public int available() throws IOException { stateChangeLock.lock(); @@ -323,6 +318,11 @@ public long skip(long n) throws IOException { if (n <= 0L) { return 0L; } + if (n <= activeBuffer.remaining()) { + // Only skipping from active buffer is sufficient + activeBuffer.position((int) n + activeBuffer.position()); + return n; + } stateChangeLock.lock(); long skipped; try { @@ -346,21 +346,14 @@ private long skipInternal(long n) throws IOException { if (available() >= n) { // we can skip from the internal buffers int toSkip = (int) n; - if (toSkip <= activeBuffer.remaining()) { - // Only skipping from active buffer is sufficient - activeBuffer.position(toSkip + activeBuffer.position()); - if (activeBuffer.remaining() <= readAheadThresholdInBytes - && !readAheadBuffer.hasRemaining()) { - readAsync(); - } - return n; - } // We need to skip from both active buffer and read ahead buffer toSkip -= activeBuffer.remaining(); + assert(toSkip > 0); // skipping from activeBuffer already handled. activeBuffer.position(0); activeBuffer.flip(); readAheadBuffer.position(toSkip + readAheadBuffer.position()); swapBuffers(); + // Trigger async read to emptied read ahead buffer. 
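With the readAheadThresholdInBytes parameter gone, ReadAheadInputStream is constructed from just the underlying stream and a buffer size; the background task now tries to fill the whole read-ahead buffer and bails out early only when a reader is waiting (the isWaiting flag above). A minimal usage sketch, with made-up sizes and class name:

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

import org.apache.spark.io.ReadAheadInputStream;

public class ReadAheadSketch {
  public static void main(String[] args) throws IOException {
    byte[] data = new byte[4 * 1024 * 1024];
    InputStream underlying = new ByteArrayInputStream(data);

    // Only the buffer size is configurable now; 1 MiB here is arbitrary.
    try (InputStream in = new ReadAheadInputStream(underlying, 1024 * 1024)) {
      byte[] buf = new byte[8192];
      long total = 0;
      int n;
      while ((n = in.read(buf)) != -1) {
        total += n;
      }
      System.out.println(total);  // 4194304
    }
  }
}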
readAsync(); return n; } else { diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index e8d3730daa7a..8651a639c07f 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -172,10 +172,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { currentEntry = sortedConsumers.lastEntry(); } List cList = currentEntry.getValue(); - MemoryConsumer c = cList.remove(cList.size() - 1); - if (cList.isEmpty()) { - sortedConsumers.remove(currentEntry.getKey()); - } + MemoryConsumer c = cList.get(cList.size() - 1); try { long released = c.spill(required - got, consumer); if (released > 0) { @@ -185,6 +182,11 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { if (got >= required) { break; } + } else { + cList.remove(cList.size() - 1); + if (cList.isEmpty()) { + sortedConsumers.remove(currentEntry.getKey()); + } } } catch (ClosedByInterruptException e) { // This called by user to kill a task (e.g: speculative task). @@ -309,7 +311,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { // this could trigger spilling to free some pages. return allocatePage(size, consumer); } - page.pageNumber = pageNumber; + page.setPageNumber(pageNumber); pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); @@ -321,17 +323,25 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.pageNumber != -1) : + assert (page.getPageNumber() != MemoryBlock.NO_PAGE_NUMBER) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - assert(allocatedPages.get(page.pageNumber)); - pageTable[page.pageNumber] = null; + assert (page.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; + assert (page.getPageNumber() != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; + assert(allocatedPages.get(page.getPageNumber())); + pageTable[page.getPageNumber()] = null; synchronized (this) { - allocatedPages.clear(page.pageNumber); + allocatedPages.clear(page.getPageNumber()); } if (logger.isTraceEnabled()) { - logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + logger.trace("Freed page number {} ({} bytes)", page.getPageNumber(), page.size()); } long pageSize = page.size(); + // Clear the page number before passing the block to the MemoryAllocator's free(). + // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed + // page has been inappropriately directly freed without calling TMM.freePage(). + page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); memoryManager.tungstenMemoryAllocator().free(page); releaseExecutionMemory(pageSize, consumer); } @@ -353,12 +363,12 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { // relative to the page's base offset; this relative offset will fit in 51 bits. 
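TaskMemoryManager now tracks page state through MemoryBlock.getPageNumber()/setPageNumber() and sentinel values (FREED_IN_TMM_PAGE_NUMBER, FREED_IN_ALLOCATOR_PAGE_NUMBER) so double frees can be caught by assertions. The encoded record addresses themselves are unchanged: a page number in the upper bits above 51 bits of in-page offset. A standalone sketch of that packing; the constants are copied to mirror TaskMemoryManager and the helper names are illustrative.

public class PageAddressSketch {
  // Mirror the constants used by TaskMemoryManager: 13 bits of page number, 51 bits of offset.
  private static final int OFFSET_BITS = 51;
  private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;

  static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
    assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page";
    return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
  }

  static int decodePageNumber(long pagePlusOffsetAddress) {
    return (int) (pagePlusOffsetAddress >>> OFFSET_BITS);
  }

  static long decodeOffset(long pagePlusOffsetAddress) {
    return pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS;
  }

  public static void main(String[] args) {
    long address = encodePageNumberAndOffset(3, 1234L);
    System.out.println(decodePageNumber(address));  // 3
    System.out.println(decodeOffset(address));      // 1234
  }
}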
offsetInPage -= page.getBaseOffset(); } - return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + return encodePageNumberAndOffset(page.getPageNumber(), offsetInPage); } @VisibleForTesting public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { - assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page"; return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); } @@ -424,6 +434,7 @@ public long cleanUpAllAllocatedMemory() { for (MemoryBlock page : pageTable) { if (page != null) { logger.debug("unreleased page: " + page + " in task " + taskAttemptId); + page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index dc36809d8911..8f49859746b8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -20,7 +20,6 @@ import java.util.Comparator; import org.apache.spark.memory.MemoryConsumer; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; @@ -105,13 +104,7 @@ public void reset() { public void expandPointerArray(LongArray newArray) { assert(newArray.size() > array.size()); - Platform.copyMemory( - array.getBaseObject(), - array.getBaseOffset(), - newArray.getBaseObject(), - newArray.getBaseOffset(), - pos * 8L - ); + MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -180,10 +173,7 @@ public ShuffleSorterIterator getSortedIterator() { PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX, PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false); } else { - MemoryBlock unused = new MemoryBlock( - array.getBaseObject(), - array.getBaseOffset() + pos * 8L, - (array.size() - pos) * 8L); + MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new ShuffleSortDataFormat(buffer)); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 717bdd79d47e..254449e95443 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -17,8 +17,8 @@ package org.apache.spark.shuffle.sort; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; final class ShuffleSortDataFormat extends SortDataFormat { @@ -60,13 +60,8 @@ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { @Override public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { - Platform.copyMemory( - src.getBaseObject(), - src.getBaseOffset() + srcPos * 8L, - dst.getBaseObject(), - dst.getBaseOffset() + dstPos * 8L, - length * 8L - ); + 
MemoryBlock.copyMemory(src.memoryBlock(), srcPos * 8L, + dst.memoryBlock(),dstPos * 8L,length * 8L); } @Override diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java index 09e425879220..02b5de8e128c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -32,6 +32,8 @@ public abstract class RecordComparator { public abstract int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset); + long rightBaseOffset, + int rightBaseLength); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 66118f454159..4fc19b172151 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -544,7 +544,7 @@ public long spill() throws IOException { // is accessing the current record. We free this page in that caller's next loadNext() // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.pageNumber != + if (!loaded || page.getPageNumber() != ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) { released += page.size(); freePage(page); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 951d076420ee..717823ebbd32 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -26,7 +26,6 @@ import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -62,12 +61,13 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { int uaoSize = UnsafeAlignedOffset.getUaoSize(); if (prefixComparisonResult == 0) { final Object baseObject1 = memoryManager.getPage(r1.recordPointer); - // skip length final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize; + final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); final Object baseObject2 = memoryManager.getPage(r2.recordPointer); - // skip length final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize; - return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize); + return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, + baseOffset2, baseLength2); } else { return prefixComparisonResult; } @@ -124,7 +124,7 @@ public UnsafeInMemorySorter( int initialSize, boolean canUseRadixSort) { this(consumer, memoryManager, recordComparator, prefixComparator, - consumer.allocateArray(initialSize * 2), canUseRadixSort); + 
consumer.allocateArray(initialSize * 2L), canUseRadixSort); } public UnsafeInMemorySorter( @@ -215,12 +215,7 @@ public void expandPointerArray(LongArray newArray) { if (newArray.size() < array.size()) { throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); } - Platform.copyMemory( - array.getBaseObject(), - array.getBaseOffset(), - newArray.getBaseObject(), - newArray.getBaseOffset(), - pos * 8L); + MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -347,10 +342,7 @@ public UnsafeSorterIterator getSortedIterator() { array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { - MemoryBlock unused = new MemoryBlock( - array.getBaseObject(), - array.getBaseOffset() + pos * 8L, - (array.size() - pos) * 8L); + MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new UnsafeSortDataFormat(buffer)); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index d9f84d10e905..37772f41caa8 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -84,7 +84,7 @@ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int @Override public LongArray allocate(int length) { - assert (length * 2 <= buffer.size()) : + assert (length * 2L <= buffer.size()) : "the buffer is smaller than required: " + buffer.size() + " < " + (length * 2); return buffer; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index cf4dfde86ca9..ff0dcc259a4a 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -35,8 +35,8 @@ final class UnsafeSorterSpillMerger { prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); if (prefixComparisonResult == 0) { return recordComparator.compare( - left.getBaseObject(), left.getBaseOffset(), - right.getBaseObject(), right.getBaseOffset()); + left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(), + right.getBaseObject(), right.getBaseOffset(), right.getRecordLength()); } else { return prefixComparisonResult; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index e2f48e5508af..fb179d07edeb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -72,10 +72,6 @@ public UnsafeSorterSpillReader( bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES; } - final double readAheadFraction = - SparkEnv.get() == null ? 
0.5 : - SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); - final boolean readAheadEnabled = SparkEnv.get() != null && SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true); @@ -84,7 +80,7 @@ public UnsafeSorterSpillReader( try { if (readAheadEnabled) { this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs), - (int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction)); + (int) bufferSizeBytes); } else { this.in = serializerManager.wrapStream(blockId, bs); } diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index d430d8c5fb35..6717af3ac4da 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -25,12 +25,18 @@ function getThreadDumpEnabled() { return threadDumpEnabled; } -function formatStatus(status, type) { +function formatStatus(status, type, row) { + if (row.isBlacklisted) { + return "Blacklisted"; + } + if (status) { - return "Active" - } else { - return "Dead" + if (row.blacklistedInStages.length == 0) { + return "Active" + } + return "Active (Blacklisted in Stages: [" + row.blacklistedInStages.join(", ") + "])"; } + return "Dead" } jQuery.extend(jQuery.fn.dataTableExt.oSort, { @@ -415,9 +421,10 @@ $(document).ready(function () { } }, {data: 'hostPort'}, - {data: 'isActive', render: function (data, type, row) { - if (row.isBlacklisted) return "Blacklisted"; - else return formatStatus (data, type); + { + data: 'isActive', + render: function (data, type, row) { + return formatStatus (data, type, row); } }, {data: 'rddBlocks'}, diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 2cde66b081a1..abc2ec0fa653 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -108,7 +108,12 @@ $(document).ready(function() { requestedIncomplete = getParameterByName("showIncomplete", searchString); requestedIncomplete = (requestedIncomplete == "true" ? true : false); - $.getJSON("api/v1/applications?limit=" + appLimit, function(response,status,jqXHR) { + appParams = { + limit: appLimit, + status: (requestedIncomplete ? 
"running" : "completed") + }; + + $.getJSON(uiRoot + "/api/v1/applications", appParams, function(response,status,jqXHR) { var array = []; var hasMultipleAttempts = false; for (i in response) { @@ -146,7 +151,7 @@ $(document).ready(function() { "showCompletedColumns": !requestedIncomplete, } - $.get("static/historypage-template.html", function(template) { + $.get(uiRoot + "/static/historypage-template.html", function(template) { var sibling = historySummary.prev(); historySummary.detach(); var apps = $(Mustache.render($(template).filter("#history-summary-template").html(),data)); diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 0fa1fcf25f8b..f01c567ba58a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -50,4 +50,37 @@ function collapseTable(thisName, table){ // to remember if it's collapsed on each page reload $(function() { collapseTablePageLoad('collapse-aggregated-metrics','aggregated-metrics'); + collapseTablePageLoad('collapse-aggregated-executors','aggregated-executors'); + collapseTablePageLoad('collapse-aggregated-removedExecutors','aggregated-removedExecutors'); + collapseTablePageLoad('collapse-aggregated-workers','aggregated-workers'); + collapseTablePageLoad('collapse-aggregated-activeApps','aggregated-activeApps'); + collapseTablePageLoad('collapse-aggregated-activeDrivers','aggregated-activeDrivers'); + collapseTablePageLoad('collapse-aggregated-completedApps','aggregated-completedApps'); + collapseTablePageLoad('collapse-aggregated-completedDrivers','aggregated-completedDrivers'); + collapseTablePageLoad('collapse-aggregated-runningExecutors','aggregated-runningExecutors'); + collapseTablePageLoad('collapse-aggregated-runningDrivers','aggregated-runningDrivers'); + collapseTablePageLoad('collapse-aggregated-finishedExecutors','aggregated-finishedExecutors'); + collapseTablePageLoad('collapse-aggregated-finishedDrivers','aggregated-finishedDrivers'); + collapseTablePageLoad('collapse-aggregated-runtimeInformation','aggregated-runtimeInformation'); + collapseTablePageLoad('collapse-aggregated-sparkProperties','aggregated-sparkProperties'); + collapseTablePageLoad('collapse-aggregated-systemProperties','aggregated-systemProperties'); + collapseTablePageLoad('collapse-aggregated-classpathEntries','aggregated-classpathEntries'); + collapseTablePageLoad('collapse-aggregated-activeJobs','aggregated-activeJobs'); + collapseTablePageLoad('collapse-aggregated-completedJobs','aggregated-completedJobs'); + collapseTablePageLoad('collapse-aggregated-failedJobs','aggregated-failedJobs'); + collapseTablePageLoad('collapse-aggregated-poolTable','aggregated-poolTable'); + collapseTablePageLoad('collapse-aggregated-allActiveStages','aggregated-allActiveStages'); + collapseTablePageLoad('collapse-aggregated-allPendingStages','aggregated-allPendingStages'); + collapseTablePageLoad('collapse-aggregated-allCompletedStages','aggregated-allCompletedStages'); + collapseTablePageLoad('collapse-aggregated-allSkippedStages','aggregated-allSkippedStages'); + collapseTablePageLoad('collapse-aggregated-allFailedStages','aggregated-allFailedStages'); + collapseTablePageLoad('collapse-aggregated-activeStages','aggregated-activeStages'); + collapseTablePageLoad('collapse-aggregated-pendingOrSkippedStages','aggregated-pendingOrSkippedStages'); + 
collapseTablePageLoad('collapse-aggregated-completedStages','aggregated-completedStages'); + collapseTablePageLoad('collapse-aggregated-failedStages','aggregated-failedStages'); + collapseTablePageLoad('collapse-aggregated-poolActiveStages','aggregated-poolActiveStages'); + collapseTablePageLoad('collapse-aggregated-tasks','aggregated-tasks'); + collapseTablePageLoad('collapse-aggregated-rdds','aggregated-rdds'); + collapseTablePageLoad('collapse-aggregated-activeBatches','aggregated-activeBatches'); + collapseTablePageLoad('collapse-aggregated-completedBatches','aggregated-completedBatches'); }); \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index ca52ecafa2cc..9ea6d2fa2fd9 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -76,6 +76,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val mapSideCombine: Boolean = false) extends Dependency[Product2[K, V]] { + if (mapSideCombine) { + require(aggregator.isDefined, "Map-side combine without Aggregator specified!") + } override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] private[spark] val keyClassName: String = reflect.classTag[K].runtimeClass.getName diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 9112d93a86b2..63d87b4cd385 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -55,18 +55,18 @@ private[spark] trait ExecutorAllocationClient { /** * Request that the cluster manager kill the specified executors. * - * When asking the executor to be replaced, the executor loss is considered a failure, and - * killed tasks that are running on the executor will count towards the failure limits. If no - * replacement is being requested, then the tasks will not count towards the limit. - * * @param executorIds identifiers of executors to kill - * @param replace whether to replace the killed executors with new ones, default false + * @param adjustTargetNumExecutors whether the target number of executors will be adjusted down + * after these executors have been killed + * @param countFailures if there are tasks running on the executors when they are killed, whether + * to count those failures toward task failure limits * @param force whether to force kill busy executors, default false * @return the ids of the executors acknowledged by the cluster manager to be removed. */ def killExecutors( executorIds: Seq[String], - replace: Boolean = false, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean = false): Seq[String] /** @@ -81,7 +81,8 @@ private[spark] trait ExecutorAllocationClient { * @return whether the request is acknowledged by the cluster manager. 
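The new guard added to ShuffleDependency above turns a previously silent misconfiguration into a fail-fast error: requesting map-side combine without supplying an Aggregator now throws at dependency construction time. A plain-Java rendering of the same precondition (Scala's require maps to an IllegalArgumentException); the class and method names are illustrative.

import java.util.Optional;

public class MapSideCombineCheckSketch {
  // Mirrors the require() added to ShuffleDependency: map-side combine needs an aggregator.
  static void checkMapSideCombine(boolean mapSideCombine, Optional<?> aggregator) {
    if (mapSideCombine && !aggregator.isPresent()) {
      throw new IllegalArgumentException(
          "requirement failed: Map-side combine without Aggregator specified!");
    }
  }

  public static void main(String[] args) {
    checkMapSideCombine(false, Optional.empty());          // fine: no map-side combine requested
    checkMapSideCombine(true, Optional.of(new Object()));  // fine: aggregator present
    try {
      checkMapSideCombine(true, Optional.empty());         // invalid combination
    } catch (IllegalArgumentException e) {
      System.out.println(e.getMessage());
    }
  }
}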
*/ def killExecutor(executorId: String): Boolean = { - val killedExecutors = killExecutors(Seq(executorId)) + val killedExecutors = killExecutors(Seq(executorId), adjustTargetNumExecutors = true, + countFailures = false) killedExecutors.nonEmpty && killedExecutors(0).equals(executorId) } } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 2e00dc8b49dd..aa363eeffffb 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -26,9 +26,10 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} +import org.apache.spark.internal.config._ import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerMaster import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** @@ -68,6 +69,10 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors * spark.dynamicAllocation.initialExecutors - Number of executors to start with * + * spark.dynamicAllocation.executorAllocationRatio - + * This is used to reduce the parallelism of the dynamic allocation that can waste + * resources when tasks are small + * * spark.dynamicAllocation.schedulerBacklogTimeout (M) - * If there are backlogged tasks for this duration, add new executors * @@ -81,7 +86,8 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} private[spark] class ExecutorAllocationManager( client: ExecutorAllocationClient, listenerBus: LiveListenerBus, - conf: SparkConf) + conf: SparkConf, + blockManagerMaster: BlockManagerMaster) extends Logging { allocationManager => @@ -114,9 +120,12 @@ private[spark] class ExecutorAllocationManager( // TODO: The default value of 1 for spark.executor.cores works right now because dynamic // allocation is only supported for YARN and the default number of cores per executor in YARN is // 1, but it might need to be attained differently for different cluster managers - private val tasksPerExecutor = + private val tasksPerExecutorForFullParallelism = conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + private val executorAllocationRatio = + conf.get(DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO) + validateSettings() // Number of executors to add in the next round @@ -151,7 +160,7 @@ private[spark] class ExecutorAllocationManager( private var clock: Clock = new SystemClock() // Listener for Spark events that impact the allocation policy - private val listener = new ExecutorAllocationListener + val listener = new ExecutorAllocationListener // Executor that handles the scheduling task. 
private val executor = @@ -195,8 +204,11 @@ private[spark] class ExecutorAllocationManager( throw new SparkException( "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") } - if (executorIdleTimeoutS <= 0) { - throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") + if (executorIdleTimeoutS < 0) { + throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be >= 0!") + } + if (cachedExecutorIdleTimeoutS < 0) { + throw new SparkException("spark.dynamicAllocation.cachedExecutorIdleTimeout must be >= 0!") } // Require external shuffle service for dynamic allocation // Otherwise, we may lose shuffle files when killing executors @@ -204,8 +216,13 @@ private[spark] class ExecutorAllocationManager( throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. You may enable this through spark.shuffle.service.enabled.") } - if (tasksPerExecutor == 0) { - throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") + if (tasksPerExecutorForFullParallelism == 0) { + throw new SparkException("spark.executor.cores must not be < spark.task.cpus.") + } + + if (executorAllocationRatio > 1.0 || executorAllocationRatio <= 0.0) { + throw new SparkException( + "spark.dynamicAllocation.executorAllocationRatio must be > 0 and <= 1.0") } } @@ -268,7 +285,9 @@ private[spark] class ExecutorAllocationManager( */ private def maxNumExecutorsNeeded(): Int = { val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks - (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor + math.ceil(numRunningOrPendingTasks * executorAllocationRatio / + tasksPerExecutorForFullParallelism) + .toInt } private def totalRunningTasks(): Int = synchronized { @@ -331,6 +350,11 @@ private[spark] class ExecutorAllocationManager( // If the new target has not changed, avoid sending a message to the cluster manager if (numExecutorsTarget < oldNumExecutorsTarget) { + // We lower the target number of executors but don't actively kill any yet. Killing is + // controlled separately by an idle timeout. It's still helpful to reduce the target number + // in case an executor just happens to get lost (eg., bad hardware, or the cluster manager + // preempts it) -- in that case, there is no point in trying to immediately get a new + // executor, since we wouldn't even use it yet. client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") @@ -452,7 +476,10 @@ private[spark] class ExecutorAllocationManager( val executorsRemoved = if (testing) { executorIdsToBeRemoved } else { - client.killExecutors(executorIdsToBeRemoved) + // We don't want to change our target number of executors, because we already did that + // when the task backlog decreased. + client.killExecutors(executorIdsToBeRemoved, adjustTargetNumExecutors = false, + countFailures = false, force = false) } // [SPARK-21834] killExecutors api reduces the target number of executors. // So we need to update the target with desired value. @@ -572,7 +599,7 @@ private[spark] class ExecutorAllocationManager( // Note that it is not necessary to query the executors since all the cached // blocks we are concerned with are reported to the driver. Note that this // does not include broadcast blocks. 
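The maxNumExecutorsNeeded change above is where the new spark.dynamicAllocation.executorAllocationRatio takes effect: the executor target is ceil(outstandingTasks * ratio / tasksPerExecutorForFullParallelism) instead of always sizing for full parallelism. A small arithmetic sketch of that formula; the method and variable names mirror the Scala code but the inputs are made up.

public class AllocationRatioSketch {
  static int maxNumExecutorsNeeded(int pendingAndRunningTasks,
                                   double executorAllocationRatio,
                                   int executorCores,
                                   int taskCpus) {
    int tasksPerExecutorForFullParallelism = executorCores / taskCpus;
    return (int) Math.ceil(
        pendingAndRunningTasks * executorAllocationRatio / tasksPerExecutorForFullParallelism);
  }

  public static void main(String[] args) {
    // 1000 outstanding tasks, 4 cores per executor, 1 CPU per task:
    System.out.println(maxNumExecutorsNeeded(1000, 1.0, 4, 1));  // 250 (full parallelism)
    System.out.println(maxNumExecutorsNeeded(1000, 0.5, 4, 1));  // 125 (half as many executors)
  }
}

A ratio below 1.0 deliberately under-provisions executors for stages with many small tasks, trading some parallelism for less allocation churn.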
- val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId) + val hasCachedBlocks = blockManagerMaster.hasCachedBlocks(executorId) val now = clock.getTimeMillis() val timeout = { if (hasCachedBlocks) { @@ -607,7 +634,7 @@ private[spark] class ExecutorAllocationManager( * This class is intentionally conservative in its assumptions about the relative ordering * and consistency of events returned by the listener. */ - private class ExecutorAllocationListener extends SparkListener { + private[spark] class ExecutorAllocationListener extends SparkListener { private val stageIdToNumTasks = new mutable.HashMap[Int, Int] // Number of running tasks per stage including speculative tasks. diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 195fd4f818b3..73646051f264 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.reflect.ClassTag @@ -282,7 +282,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) } @@ -296,7 +296,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -632,9 +632,10 @@ private[spark] class MapOutputTrackerMaster( } } + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -642,7 +643,7 @@ private[spark] class MapOutputTrackerMaster( MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } case None => - Seq.empty + Iterator.empty } } @@ -669,8 +670,9 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr /** Remembers which map output locations are currently being fetched on an executor. */ private val fetching = new HashSet[Int] + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. 
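The MapOutputTracker changes above (and the convertMapStatuses hunk that follows) group shuffle block sizes by executor location while dropping zero-sized blocks up front, and hand the result back as an iterator instead of a materialized sequence, which avoids building large intermediate collections for wide shuffles. A language-neutral sketch of that grouping step; the locations, sizes, and class name are made up.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

public class MapStatusGroupingSketch {
  public static void main(String[] args) {
    // blockSizes[mapId][partition]: sizes reported by each map task (values are made up).
    long[][] blockSizes = {
        {0L, 128L, 0L},
        {64L, 0L, 256L},
    };
    // Pretend map task 0 ran on "exec-a" and map task 1 on "exec-b".
    String[] locations = {"exec-a", "exec-b"};

    Map<String, List<String>> splitsByAddress = new HashMap<>();
    for (int mapId = 0; mapId < blockSizes.length; mapId++) {
      for (int part = 0; part < blockSizes[mapId].length; part++) {
        long size = blockSizes[mapId][part];
        if (size != 0) {  // zero-sized blocks are filtered out and never fetched
          splitsByAddress.computeIfAbsent(locations[mapId], k -> new ArrayList<>())
              .add("shuffle_0_" + mapId + "_" + part + " -> " + size);
        }
      }
    }

    // Expose an iterator over the grouped result rather than a materialized sequence.
    Iterator<Map.Entry<String, List<String>>> it = splitsByAddress.entrySet().iterator();
    while (it.hasNext()) {
      System.out.println(it.next());
    }
  }
}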
override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -841,6 +843,7 @@ private[spark] object MapOutputTracker extends Logging { * Given an array of map statuses and a range of map output partitions, returns a sequence that, * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes * stored at that block manager. + * Note that empty blocks are filtered in the result. * * If any of the statuses is null (indicating a missing location due to a failed mapper), * throws a FetchFailedException. @@ -857,22 +860,24 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] - for ((status, mapId) <- statuses.zipWithIndex) { + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) } else { for (part <- startPartition until endPartition) { - splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) + val size = status.getSizeForBlock(part) + if (size != 0) { + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) + } } } } - - splitsByAddress.toSeq + splitsByAddress.iterator } } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 437bbaae1968..c940cb25d478 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -43,17 +43,19 @@ object Partitioner { /** * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. * - * If any of the RDDs already has a partitioner, and the number of partitions of the - * partitioner is either greater than or is less than and within a single order of - * magnitude of the max number of upstream partitions, choose that one. + * If spark.default.parallelism is set, we'll use the value of SparkContext defaultParallelism + * as the default partitions number, otherwise we'll use the max number of upstream partitions. * - * Otherwise, we use a default HashPartitioner. For the number of partitions, if - * spark.default.parallelism is set, then we'll use the value from SparkContext - * defaultParallelism, otherwise we'll use the max number of upstream partitions. + * When available, we choose the partitioner from rdds with maximum number of partitions. If this + * partitioner is eligible (number of partitions within an order of maximum number of partitions + * in rdds), or has partition number higher than default partitions number - we use this + * partitioner. 
* - * Unless spark.default.parallelism is set, the number of partitions will be the - * same as the number of partitions in the largest upstream RDD, as this should - * be least likely to cause out-of-memory errors. + * Otherwise, we'll use a new HashPartitioner with the default partitions number. + * + * Unless spark.default.parallelism is set, the number of partitions will be the same as the + * number of partitions in the largest upstream RDD, as this should be least likely to cause + * out-of-memory errors. * * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD. */ @@ -67,31 +69,32 @@ object Partitioner { None } - if (isEligiblePartitioner(hasMaxPartitioner, rdds)) { + val defaultNumPartitions = if (rdd.context.conf.contains("spark.default.parallelism")) { + rdd.context.defaultParallelism + } else { + rdds.map(_.partitions.length).max + } + + // If the existing max partitioner is an eligible one, or its partitions number is larger + // than the default number of partitions, use the existing partitioner. + if (hasMaxPartitioner.nonEmpty && (isEligiblePartitioner(hasMaxPartitioner.get, rdds) || + defaultNumPartitions < hasMaxPartitioner.get.getNumPartitions)) { hasMaxPartitioner.get.partitioner.get } else { - if (rdd.context.conf.contains("spark.default.parallelism")) { - new HashPartitioner(rdd.context.defaultParallelism) - } else { - new HashPartitioner(rdds.map(_.partitions.length).max) - } + new HashPartitioner(defaultNumPartitions) } } /** - * Returns true if the number of partitions of the RDD is either greater - * than or is less than and within a single order of magnitude of the - * max number of upstream partitions; - * otherwise, returns false + * Returns true if the number of partitions of the RDD is either greater than or is less than and + * within a single order of magnitude of the max number of upstream partitions, otherwise returns + * false. */ private def isEligiblePartitioner( - hasMaxPartitioner: Option[RDD[_]], + hasMaxPartitioner: RDD[_], rdds: Seq[RDD[_]]): Boolean = { - if (hasMaxPartitioner.isEmpty) { - return false - } val maxPartitions = rdds.map(_.partitions.length).max - log10(maxPartitions) - log10(hasMaxPartitioner.get.getNumPartitions) < 1 + log10(maxPartitions) - log10(hasMaxPartitioner.getNumPartitions) < 1 } } diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 4c1dbe3ffb4a..dbfd5a514c18 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -42,148 +42,10 @@ import org.apache.spark.util.Utils * should access it from that. There are some cases where the SparkEnv hasn't been * initialized yet and this class must be instantiated directly. * - * Spark currently supports authentication via a shared secret. - * Authentication can be configured to be on via the 'spark.authenticate' configuration - * parameter. This parameter controls whether the Spark communication protocols do - * authentication using the shared secret. This authentication is a basic handshake to - * make sure both sides have the same shared secret and are allowed to communicate. - * If the shared secret is not identical they will not be allowed to communicate. - * - * The Spark UI can also be secured by using javax servlet filters. A user may want to - * secure the UI if it has data that other users should not be allowed to see. 
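A worked example of the eligibility rule used by the Partitioner.scala hunk above (helper name and numbers are mine): an existing partitioner is reused when it is within one order of magnitude of the largest upstream partition count, or, after this patch, whenever it already has more partitions than the default.

    import scala.math.log10

    def isEligible(existingPartitions: Int, maxUpstreamPartitions: Int): Boolean =
      log10(maxUpstreamPartitions) - log10(existingPartitions) < 1

    val reused = isEligible(20, 100)   // true:  2.00 - 1.30 = 0.70 < 1
    val tooFew = isEligible(9, 100)    // false: 2.00 - 0.95 = 1.05 >= 1
    // With the change above, a partitioner failing this test is still kept if its
    // partition count exceeds spark.default.parallelism (or the upstream maximum).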
The javax - * servlet filter specified by the user can authenticate the user and then once the user - * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.acls.enable', 'spark.ui.view.acls' and - * 'spark.ui.view.acls.groups' control the behavior of the acls. Note that the person who - * started the application always has view access to the UI. - * - * Spark has a set of individual and group modify acls (`spark.modify.acls`) and - * (`spark.modify.acls.groups`) that controls which users and groups have permission to - * modify a single application. This would include things like killing the application. - * By default the person who started the application has modify access. For modify access - * through the UI, you must have a filter that does authentication in place for the modify - * acls to work properly. - * - * Spark also has a set of individual and group admin acls (`spark.admin.acls`) and - * (`spark.admin.acls.groups`) which is a set of users/administrators and admin groups - * who always have permission to view or modify the Spark application. - * - * Starting from version 1.3, Spark has partial support for encrypted connections with SSL. - * - * At this point spark has multiple communication protocols that need to be secured and - * different underlying mechanisms are used depending on the protocol: - * - * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty - * for the HttpServer. Jetty supports multiple authentication mechanisms - - * Basic, Digest, Form, Spnego, etc. It also supports multiple different login - * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService - * to authenticate using DIGEST-MD5 via a single user and the shared secret. - * Since we are using DIGEST-MD5, the shared secret is not passed on the wire - * in plaintext. - * - * We currently support SSL (https) for this communication protocol (see the details - * below). - * - * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. - * Any clients must specify the user and password. There is a default - * Authenticator installed in the SecurityManager to how it does the authentication - * and in this case gets the user name and password from the request. - * - * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously - * exchange messages. For this we use the Java SASL - * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 - * as the authentication mechanism. This means the shared secret is not passed - * over the wire in plaintext. - * Note that SASL is pluggable as to what mechanism it uses. We currently use - * DIGEST-MD5 but this could be changed to use Kerberos or other in the future. - * Spark currently supports "auth" for the quality of protection, which means - * the connection does not support integrity or privacy protection (encryption) - * after authentication. SASL also supports "auth-int" and "auth-conf" which - * SPARK could support in the future to allow the user to specify the quality - * of protection they want. If we support those, the messages will also have to - * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. - * - * Since the NioBlockTransferService does asynchronous messages passing, the SASL - * authentication is a bit more complex. 
A ConnectionManager can be both a client - * and a Server, so for a particular connection it has to determine what to do. - * A ConnectionId was added to be able to track connections and is used to - * match up incoming messages with connections waiting for authentication. - * The ConnectionManager tracks all the sendingConnections using the ConnectionId, - * waits for the response from the server, and does the handshake before sending - * the real message. - * - * The NettyBlockTransferService ensures that SASL authentication is performed - * synchronously prior to any other communication on a connection. This is done in - * SaslClientBootstrap on the client side and SaslRpcHandler on the server side. - * - * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters - * can be used. Yarn requires a specific AmIpFilter be installed for security to work - * properly. For non-Yarn deployments, users can write a filter to go through their - * organization's normal login service. If an authentication filter is in place then the - * SparkUI can be configured to check the logged in user against the list of users who - * have view acls to see if that user is authorized. - * The filters can also be used for many different purposes. For instance filters - * could be used for logging, encryption, or compression. - * - * The exact mechanisms used to generate/distribute the shared secret are deployment-specific. - * - * For YARN deployments, the secret is automatically generated. The secret is placed in the Hadoop - * UGI which gets passed around via the Hadoop RPC mechanism. Hadoop RPC can be configured to - * support different levels of protection. See the Hadoop documentation for more details. Each - * Spark application on YARN gets a different shared secret. - * - * On YARN, the Spark UI gets configured to use the Hadoop YARN AmIpFilter which requires the user - * to go through the ResourceManager Proxy. That proxy is there to reduce the possibility of web - * based attacks through YARN. Hadoop can be configured to use filters to do authentication. That - * authentication then happens via the ResourceManager Proxy and Spark will use that to do - * authorization against the view acls. - * - * For other Spark deployments, the shared secret must be specified via the - * spark.authenticate.secret config. - * All the nodes (Master and Workers) and the applications need to have the same shared secret. - * This again is not ideal as one user could potentially affect another users application. - * This should be enhanced in the future to provide better protection. - * If the UI needs to be secure, the user needs to install a javax servlet filter to do the - * authentication. Spark will then use that user to compare against the view acls to do - * authorization. If not filter is in place the user is generally null and no authorization - * can take place. - * - * When authentication is being used, encryption can also be enabled by setting the option - * spark.authenticate.enableSaslEncryption to true. This is only supported by communication - * channels that use the network-common library, and can be used as an alternative to SSL in those - * cases. - * - * SSL can be used for encryption for certain communication channels. The user can configure the - * default SSL settings which will be used for all the supported communication protocols unless - * they are overwritten by protocol specific settings. 
This way the user can easily provide the - * common settings for all the protocols without disabling the ability to configure each one - * individually. - * - * All the SSL settings like `spark.ssl.xxx` where `xxx` is a particular configuration property, - * denote the global configuration for all the supported protocols. In order to override the global - * configuration for the particular protocol, the properties must be overwritten in the - * protocol-specific namespace. Use `spark.ssl.yyy.xxx` settings to overwrite the global - * configuration for particular protocol denoted by `yyy`. Currently `yyy` can be only`fs` for - * broadcast and file server. - * - * Refer to [[org.apache.spark.SSLOptions]] documentation for the list of - * options that can be specified. - * - * SecurityManager initializes SSLOptions objects for different protocols separately. SSLOptions - * object parses Spark configuration at a given namespace and builds the common representation - * of SSL settings. SSLOptions is then used to provide protocol-specific SSLContextFactory for - * Jetty. - * - * SSL must be configured on each node and configured for each component involved in - * communication using the particular protocol. In YARN clusters, the key-store can be prepared on - * the client side then distributed and used by the executors as the part of the application - * (YARN allows the user to deploy files before the application is started). - * In standalone deployment, the user needs to provide key-stores and configuration - * options for master and workers. In this mode, the user may allow the executors to use the SSL - * settings inherited from the worker which spawned that executor. It can be accomplished by - * setting `spark.ssl.useNodeLocalConf` to `true`. + * This class implements all of the configuration related to security features described + * in the "Security" document. Please refer to that document for specific features implemented + * here. */ - private[spark] class SecurityManager( sparkConf: SparkConf, val ioEncryptionKey: Option[Array[Byte]] = None) @@ -227,6 +89,7 @@ private[spark] class SecurityManager( setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")); setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")); + private var secretKey: String = _ logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + "; users with view permissions: " + viewAcls.toString() + @@ -256,51 +119,6 @@ private[spark] class SecurityManager( // the default SSL configuration - it will be used by all communication layers unless overwritten private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) - // SSL configuration for the file server. This is used by Utils.setupSecureURLConnection(). 
- val fileServerSSLOptions = getSSLOptions("fs") - val (sslSocketFactory, hostnameVerifier) = if (fileServerSSLOptions.enabled) { - val trustStoreManagers = - for (trustStore <- fileServerSSLOptions.trustStore) yield { - val input = Files.asByteSource(fileServerSSLOptions.trustStore.get).openStream() - - try { - val ks = KeyStore.getInstance(KeyStore.getDefaultType) - ks.load(input, fileServerSSLOptions.trustStorePassword.get.toCharArray) - - val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) - tmf.init(ks) - tmf.getTrustManagers - } finally { - input.close() - } - } - - lazy val credulousTrustStoreManagers = Array({ - logWarning("Using 'accept-all' trust manager for SSL connections.") - new X509TrustManager { - override def getAcceptedIssuers: Array[X509Certificate] = null - - override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {} - - override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {} - }: TrustManager - }) - - require(fileServerSSLOptions.protocol.isDefined, - "spark.ssl.protocol is required when enabling SSL connections.") - - val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.get) - sslContext.init(null, trustStoreManagers.getOrElse(credulousTrustStoreManagers), null) - - val hostVerifier = new HostnameVerifier { - override def verify(s: String, sslSession: SSLSession): Boolean = true - } - - (Some(sslContext.getSocketFactory), Some(hostVerifier)) - } else { - (None, None) - } - def getSSLOptions(module: String): SSLOptions = { val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", Some(defaultSSLOptions)) logDebug(s"Created SSL options for $module: $opts") @@ -504,6 +322,12 @@ private[spark] class SecurityManager( val creds = UserGroupInformation.getCurrentUser().getCredentials() Option(creds.getSecretKey(SECRET_LOOKUP_KEY)) .map { bytes => new String(bytes, UTF_8) } + // Secret key may not be found in current UGI's credentials. + // This happens when UGI is refreshed in the driver side by UGI's loginFromKeytab but not + // copy secret key from original UGI to the new one. This exists in ThriftServer's Hive + // logic. So as a workaround, storing secret key in a local variable to make it visible + // in different context. + .orElse(Option(secretKey)) .orElse(Option(sparkConf.getenv(ENV_AUTH_SECRET))) .orElse(sparkConf.getOption(SPARK_AUTH_SECRET_CONF)) .getOrElse { @@ -520,19 +344,25 @@ private[spark] class SecurityManager( * * If authentication is disabled, do nothing. * - * In YARN mode, generate a new secret and store it in the current user's credentials. + * In YARN and local mode, generate a new secret and store it in the current user's credentials. * * In other modes, assert that the auth secret is set in the configuration. 
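The getSecretKey() change above inserts the new in-memory secretKey field as a fallback between the UGI credential lookup and the environment/config lookups. A condensed view of that precedence (function and parameter names are mine, not the real API):

    // Resolution order after this patch: UGI credentials, then the driver-local copy,
    // then the auth-secret environment variable, then spark.authenticate.secret.
    def resolveSecret(fromUgi: Option[String],
                      driverLocalCopy: Option[String],
                      fromEnv: Option[String],
                      fromConf: Option[String]): String =
      fromUgi.orElse(driverLocalCopy).orElse(fromEnv).orElse(fromConf)
        .getOrElse(throw new IllegalArgumentException(
          "A secret key must be specified via the spark.authenticate.secret config"))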
*/ def initializeAuth(): Unit = { + import SparkMasterRegex._ + if (!sparkConf.get(NETWORK_AUTH_ENABLED)) { return } - if (sparkConf.get(SparkLauncher.SPARK_MASTER, null) != "yarn") { - require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), - s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") - return + val master = sparkConf.get(SparkLauncher.SPARK_MASTER, "") + master match { + case "yarn" | "local" | LOCAL_N_REGEX(_) | LOCAL_N_FAILURES_REGEX(_, _) => + // Secret generation allowed here + case _ => + require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), + s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") + return } val rnd = new SecureRandom() @@ -541,7 +371,8 @@ private[spark] class SecurityManager( rnd.nextBytes(secretBytes) val creds = new Credentials() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretBytes) + secretKey = HashCodes.fromBytes(secretBytes).toString() + creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index d77303e6fdf8..129956e9f9ff 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -603,13 +603,15 @@ private[spark] object SparkConf extends Logging { "Please use spark.kryoserializer.buffer instead. The default value for " + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'."), - DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), + DeprecatedConfig("spark.rpc", "2.0", "Not used anymore."), DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", "Please use the new blacklisting options, spark.blacklist.*"), - DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"), - DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used anymore"), + DeprecatedConfig("spark.executor.port", "2.0.0", "Not used anymore"), DeprecatedConfig("spark.shuffle.service.index.cache.entries", "2.3.0", - "Not used any more. Please use spark.shuffle.service.index.cache.size") + "Not used anymore. 
Please use spark.shuffle.service.index.cache.size"), + DeprecatedConfig("spark.yarn.credentials.file.retention.count", "2.4.0", "Not used anymore."), + DeprecatedConfig("spark.yarn.credentials.file.retention.days", "2.4.0", "Not used anymore.") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) @@ -640,9 +642,9 @@ private[spark] object SparkConf extends Logging { translation = s => s"${s.toLong * 10}s")), "spark.reducer.maxSizeInFlight" -> Seq( AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), - "spark.kryoserializer.buffer" -> - Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", - translation = s => s"${(s.toDouble * 1000).toInt}k")), + "spark.kryoserializer.buffer" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), "spark.shuffle.file.buffer" -> Seq( @@ -748,7 +750,7 @@ private[spark] object SparkConf extends Logging { } if (key.startsWith("spark.akka") || key.startsWith("spark.ssl.akka")) { logWarning( - s"The configuration key $key is not supported any more " + + s"The configuration key $key is not supported anymore " + s"because Spark doesn't use Akka since 2.0") } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 31f3cb9dfa0a..5e8595603cc9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -54,6 +54,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} @@ -533,7 +534,8 @@ class SparkContext(config: SparkConf) extends Logging { schedulerBackend match { case b: ExecutorAllocationClient => Some(new ExecutorAllocationManager( - schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf)) + schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf, + _env.blockManager.master)) case _ => None } @@ -1632,6 +1634,8 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. * + * This is not supported when dynamic allocation is turned on. + * * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. 
If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to @@ -1643,7 +1647,10 @@ class SparkContext(config: SparkConf) extends Logging { def killExecutors(executorIds: Seq[String]): Boolean = { schedulerBackend match { case b: ExecutorAllocationClient => - b.killExecutors(executorIds, replace = false, force = true).nonEmpty + require(executorAllocationManager.isEmpty, + "killExecutors() unsupported with Dynamic Allocation turned on") + b.killExecutors(executorIds, adjustTargetNumExecutors = true, countFailures = false, + force = true).nonEmpty case _ => logWarning("Killing executors is not supported by current scheduler.") false @@ -1681,7 +1688,8 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { schedulerBackend match { case b: ExecutorAllocationClient => - b.killExecutors(Seq(executorId), replace = true, force = true).nonEmpty + b.killExecutors(Seq(executorId), adjustTargetNumExecutors = false, countFailures = true, + force = true).nonEmpty case _ => logWarning("Killing executors is not supported by current scheduler.") false @@ -1715,7 +1723,13 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def getRDDStorageInfo(filter: RDD[_] => Boolean): Array[RDDInfo] = { assertNotStopped() val rddInfos = persistentRdds.values.filter(filter).map(RDDInfo.fromRdd).toArray - StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) + rddInfos.foreach { rddInfo => + val rddId = rddInfo.id + val rddStorageInfo = statusStore.asOption(statusStore.rdd(rddId)) + rddInfo.numCachedPartitions = rddStorageInfo.map(_.numCachedPartitions).getOrElse(0) + rddInfo.memSize = rddStorageInfo.map(_.memoryUsed).getOrElse(0L) + rddInfo.diskSize = rddStorageInfo.map(_.diskUsed).getOrElse(0L) + } rddInfos.filter(_.isCached) } @@ -1726,17 +1740,6 @@ class SparkContext(config: SparkConf) extends Logging { */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - /** - * :: DeveloperApi :: - * Return information about blocks stored in all of the slaves - */ - @DeveloperApi - @deprecated("This method may change or be removed in a future release.", "2.2.0") - def getExecutorStorageStatus: Array[StorageStatus] = { - assertNotStopped() - env.blockManager.master.getStorageStatus - } - /** * :: DeveloperApi :: * Return pools for fair scheduler @@ -2276,7 +2279,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Clean a closure to make it ready to be serialized and send to tasks + * Clean a closure to make it ready to be serialized and sent to tasks * (removes unreferenced variables in $outer's, updates REPL variables) * If checkSerializable is set, clean will also proactively * check to see if f is serializable and throw a SparkException diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 70865cb58c57..815237eba017 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -97,7 +97,8 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore } /** - * Returns information of all known executors, including host, port, cacheSize, numRunningTasks. + * Returns information of all known executors, including host, port, cacheSize, numRunningTasks + * and memory metrics. 
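With the new require() above, explicit executor kills are rejected once dynamic allocation owns the executor count. A small usage sketch (helper name is mine):

    import org.apache.spark.SparkContext

    // Only ask the cluster manager to kill executors when dynamic allocation is off;
    // otherwise killExecutors() now fails fast instead of fighting the allocation manager.
    def shrinkCluster(sc: SparkContext, executorIds: Seq[String]): Boolean =
      if (sc.getConf.getBoolean("spark.dynamicAllocation.enabled", false)) {
        false   // leave sizing to the ExecutorAllocationManager
      } else {
        sc.killExecutors(executorIds)   // lowers the target and force-kills the given ids
      }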
*/ def getExecutorInfos: Array[SparkExecutorInfo] = { store.executorList(true).map { exec => @@ -113,7 +114,11 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore host, port, cachedMem, - exec.activeTasks) + exec.activeTasks, + exec.memoryMetrics.map(_.usedOffHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.usedOnHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.totalOffHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.totalOnHeapStorageMemory).getOrElse(0L)) }.toArray } } diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala index c1f24a637778..6a888c1e9e77 100644 --- a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala +++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala @@ -38,5 +38,9 @@ private class SparkExecutorInfoImpl( val host: String, val port: Int, val cacheSize: Long, - val numRunningTasks: Int) + val numRunningTasks: Int, + val usedOnHeapStorageMemory: Long, + val usedOffHeapStorageMemory: Long, + val totalOnHeapStorageMemory: Long, + val totalOffHeapStorageMemory: Long) extends SparkExecutorInfo diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0b87cd503d4f..69739745aa6c 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -66,7 +66,7 @@ object TaskContext { * An empty task context that does not represent an actual task. This is only used in tests. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) + new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null) } } @@ -150,6 +150,13 @@ abstract class TaskContext extends Serializable { */ def stageId(): Int + /** + * How many times the stage that this task belongs to has been attempted. The first stage attempt + * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt + * numbers. + */ + def stageAttemptNumber(): Int + /** * The ID of the RDD partition that is computed by this task. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 01d8973e1bb0..cccd3ea457ba 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -41,8 +41,9 @@ import org.apache.spark.util._ * `TaskMetrics` & `MetricsSystem` objects are not thread safe. 
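stageAttemptNumber() is plumbed through TaskContextImpl in the hunk that follows. A short usage sketch of the new accessor (helper name is mine):

    import org.apache.spark.TaskContext
    import org.apache.spark.rdd.RDD

    // Log partitions that are running as part of a retried stage attempt,
    // e.g. after a fetch failure forced the stage to be resubmitted.
    def logRetriedPartitions[T](rdd: RDD[T]): Unit =
      rdd.foreachPartition { _ =>
        val ctx = TaskContext.get()
        if (ctx.stageAttemptNumber() > 0) {
          println(s"partition ${ctx.partitionId()} re-running in stage attempt " +
            s"${ctx.stageAttemptNumber()} of stage ${ctx.stageId()}")
        }
      }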
*/ private[spark] class TaskContextImpl( - val stageId: Int, - val partitionId: Int, + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 93e7ee3d2a40..b5c4c705dcbc 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -22,7 +22,7 @@ import java.net.{HttpURLConnection, URI, URL} import java.nio.charset.StandardCharsets import java.security.SecureRandom import java.security.cert.X509Certificate -import java.util.Arrays +import java.util.{Arrays, Properties} import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ @@ -35,6 +35,7 @@ import scala.sys.process.{Process, ProcessLogger} import scala.util.Try import com.google.common.io.{ByteStreams, Files} +import org.apache.log4j.PropertyConfigurator import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -256,6 +257,29 @@ private[spark] object TestUtils { s"Can't find $numExecutors executors before $timeout milliseconds elapsed") } + /** + * config a log4j properties used for testsuite + */ + def configTestLog4j(level: String): Unit = { + val pro = new Properties() + pro.put("log4j.rootLogger", s"$level, console") + pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") + pro.put("log4j.appender.console.target", "System.err") + pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") + pro.put("log4j.appender.console.layout.ConversionPattern", + "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") + PropertyConfigurator.configure(pro) + } + + /** + * Lists files recursively. 
+ */ + def recursiveList(f: File): Array[File] = { + require(f.isDirectory) + val current = f.listFiles + current ++ current.filter(_.isDirectory).flatMap(recursiveList) + } + } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 1ec0e717fac2..f075a7e0eb0b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -37,14 +37,16 @@ private[spark] object PythonEvalType { val SQL_BATCHED_UDF = 100 - val SQL_PANDAS_SCALAR_UDF = 200 - val SQL_PANDAS_GROUP_MAP_UDF = 201 + val SQL_SCALAR_PANDAS_UDF = 200 + val SQL_GROUPED_MAP_PANDAS_UDF = 201 + val SQL_GROUPED_AGG_PANDAS_UDF = 202 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" - case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF" - case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" + case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" + case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index f53c6178047f..2340580b54f6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.python -import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStreamWriter} +import java.io.{DataInputStream, DataOutputStream, EOFException, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} import java.nio.charset.StandardCharsets import java.util.Arrays @@ -34,10 +34,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String import PythonWorkerFactory._ - // Because forking processes from Java is expensive, we prefer to launch a single Python daemon - // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently - // only works on UNIX-based systems now because it uses signals for child management, so we can - // also fall back to launching workers (pyspark/worker.py) directly. + // Because forking processes from Java is expensive, we prefer to launch a single Python daemon, + // pyspark/daemon.py (by default) and tell it to fork new workers for our tasks. This daemon + // currently only works on UNIX-based systems now because it uses signals for child management, + // so we can also fall back to launching workers, pyspark/worker.py (by default) directly. val useDaemon = { val useDaemonEnabled = SparkEnv.get.conf.getBoolean("spark.python.use.daemon", true) @@ -45,6 +45,28 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled } + // WARN: Both configurations, 'spark.python.daemon.module' and 'spark.python.worker.module' are + // for very advanced users and they are experimental. This should be considered + // as expert-only option, and shouldn't be used before knowing what it means exactly. + + // This configuration indicates the module to run the daemon to execute its Python workers. 
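The two experimental module settings described in the warning above are read in the hunk that follows. A minimal configuration sketch (the module names below are hypothetical placeholders):

    import org.apache.spark.SparkConf

    // The daemon module only takes effect on non-Windows hosts with the daemon enabled;
    // the worker module is used instead when the daemon is disabled or on Windows.
    val conf = new SparkConf()
      .set("spark.python.use.daemon", "true")
      .set("spark.python.daemon.module", "my_pkg.custom_daemon")   // hypothetical module
      .set("spark.python.worker.module", "my_pkg.custom_worker")   // hypothetical module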
+ val daemonModule = SparkEnv.get.conf.getOption("spark.python.daemon.module").map { value => + logInfo( + s"Python daemon module in PySpark is set to [$value] in 'spark.python.daemon.module', " + + "using this to start the daemon up. Note that this configuration only has an effect when " + + "'spark.python.use.daemon' is enabled and the platform is not Windows.") + value + }.getOrElse("pyspark.daemon") + + // This configuration indicates the module to run each Python worker. + val workerModule = SparkEnv.get.conf.getOption("spark.python.worker.module").map { value => + logInfo( + s"Python worker module in PySpark is set to [$value] in 'spark.python.worker.module', " + + "using this to start the worker up. Note that this configuration only has an effect when " + + "'spark.python.use.daemon' is disabled or the platform is Windows.") + value + }.getOrElse("pyspark.worker") + var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 @@ -74,8 +96,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } /** - * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself - * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems. + * Connect to a worker launched through pyspark/daemon.py (by default), which forks python + * processes itself to avoid the high cost of forking from Java. This currently only works + * on UNIX-based systems. */ private def createThroughDaemon(): Socket = { @@ -108,7 +131,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } /** - * Launch a worker by executing worker.py directly and telling it to connect to us. + * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. */ private def createSimpleWorker(): Socket = { var serverSocket: ServerSocket = null @@ -116,7 +139,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule)) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) @@ -159,7 +182,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon")) + val command = Arrays.asList(pythonExec, "-m", daemonModule) + val pb = new ProcessBuilder(command) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) @@ -168,7 +192,29 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String daemon = pb.start() val in = new DataInputStream(daemon.getInputStream) - daemonPort = in.readInt() + try { + daemonPort = in.readInt() + } catch { + case _: EOFException => + throw new SparkException(s"No port number in $daemonModule's stdout") + } + + // test that the returned port number is within a valid range. + // note: this does not cover the case where the port number + // is arbitrary data but is also coincidentally within range + if (daemonPort < 1 || daemonPort > 0xffff) { + val exceptionMessage = f""" + |Bad data in $daemonModule's standard output. 
Invalid port number: + | $daemonPort (0x$daemonPort%08x) + |Python command to execute the daemon was: + | ${command.asScala.mkString(" ")} + |Check that you don't have any unexpected modules or libraries in + |your PYTHONPATH: + | $pythonPath + |Also, check if you have a sitecustomize.py module in your python path, + |or in your python installation, that is printing to standard output""" + throw new SparkException(exceptionMessage.stripMargin) + } // Redirect daemon stdout and stderr redirectStreamsToStderr(in, daemon.getErrorStream) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index e88988fe03b2..8d7a4a353a79 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -21,6 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag +import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap} + import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging @@ -52,6 +54,10 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) + private[broadcast] val cachedValues = { + new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK) + } + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 7aecd3c9668e..e125095cf477 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -206,36 +206,50 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private def readBroadcastBlock(): T = Utils.tryOrIOException { TorrentBroadcast.synchronized { - setConf(SparkEnv.get.conf) - val blockManager = SparkEnv.get.blockManager - blockManager.getLocalValues(broadcastId) match { - case Some(blockResult) => - if (blockResult.data.hasNext) { - val x = blockResult.data.next().asInstanceOf[T] - releaseLock(broadcastId) - x - } else { - throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") - } - case None => - logInfo("Started reading broadcast variable " + id) - val startTimeMs = System.currentTimeMillis() - val blocks = readBlocks() - logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - - try { - val obj = TorrentBroadcast.unBlockifyObject[T]( - blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. 
- val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + val broadcastCache = SparkEnv.get.broadcastManager.cachedValues + + Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse { + setConf(SparkEnv.get.conf) + val blockManager = SparkEnv.get.blockManager + blockManager.getLocalValues(broadcastId) match { + case Some(blockResult) => + if (blockResult.data.hasNext) { + val x = blockResult.data.next().asInstanceOf[T] + releaseLock(broadcastId) + + if (x != null) { + broadcastCache.put(broadcastId, x) + } + + x + } else { + throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") } - obj - } finally { - blocks.foreach(_.dispose()) - } + case None => + logInfo("Started reading broadcast variable " + id) + val startTimeMs = System.currentTimeMillis() + val blocks = readBlocks() + logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) + + try { + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + + if (obj != null) { + broadcastCache.put(broadcastId, obj) + } + + obj + } finally { + blocks.foreach(_.dispose()) + } + } } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index ecc82d7ac800..178bdcfccb60 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -18,21 +18,24 @@ package org.apache.spark.deploy import java.io.File +import java.net.URI import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.util.{MutableURLClassLoader, Utils} -private[deploy] object DependencyUtils { +private[deploy] object DependencyUtils extends Logging { def resolveMavenDependencies( packagesExclusions: String, packages: String, repositories: String, - ivyRepoPath: String): String = { + ivyRepoPath: String, + ivySettingsPath: Option[String]): String = { val exclusions: Seq[String] = if (!StringUtils.isBlank(packagesExclusions)) { packagesExclusions.split(",") @@ -40,10 +43,12 @@ private[deploy] object DependencyUtils { Nil } // Create the IvySettings, either load from file or build defaults - val ivySettings = sys.props.get("spark.jars.ivySettings").map { ivySettingsFile => - SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(repositories), Option(ivyRepoPath)) - }.getOrElse { - SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) + val ivySettings = ivySettingsPath match { + case Some(path) => + SparkSubmitUtils.loadIvySettings(path, Option(repositories), Option(ivyRepoPath)) + + case None => + SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) } 
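resolveMavenDependencies() above now receives the Ivy settings path explicitly instead of reading the spark.jars.ivySettings system property itself. A configuration sketch, assuming the usual spark.jars.packages companion setting (the file path and coordinate below are only examples):

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.jars.ivySettings", "/etc/spark/ivysettings.xml")          // hypothetical path
      .set("spark.jars.packages", "org.apache.commons:commons-lang3:3.5")   // resolved through it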
SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions) @@ -71,7 +76,7 @@ private[deploy] object DependencyUtils { def addJarsToClassPath(jars: String, loader: MutableURLClassLoader): Unit = { if (jars != null) { for (jar <- jars.split(",")) { - SparkSubmit.addJarToClasspath(jar, loader) + addJarToClasspath(jar, loader) } } } @@ -137,16 +142,56 @@ private[deploy] object DependencyUtils { def resolveGlobPaths(paths: String, hadoopConf: Configuration): String = { require(paths != null, "paths cannot be null.") Utils.stringToSeq(paths).flatMap { path => - val uri = Utils.resolveURI(path) - uri.getScheme match { - case "local" | "http" | "https" | "ftp" => Array(path) - case _ => - val fs = FileSystem.get(uri, hadoopConf) - Option(fs.globStatus(new Path(uri))).map { status => - status.filter(_.isFile).map(_.getPath.toUri.toString) - }.getOrElse(Array(path)) + val (base, fragment) = splitOnFragment(path) + (resolveGlobPath(base, hadoopConf), fragment) match { + case (resolved, Some(_)) if resolved.length > 1 => throw new SparkException( + s"${base.toString} resolves ambiguously to multiple files: ${resolved.mkString(",")}") + case (resolved, Some(namedAs)) => resolved.map(_ + "#" + namedAs) + case (resolved, _) => resolved } }.mkString(",") } + def addJarToClasspath(localJar: String, loader: MutableURLClassLoader): Unit = { + val uri = Utils.resolveURI(localJar) + uri.getScheme match { + case "file" | "local" => + val file = new File(uri.getPath) + if (file.exists()) { + loader.addURL(file.toURI.toURL) + } else { + logWarning(s"Local jar $file does not exist, skipping.") + } + case _ => + logWarning(s"Skip remote jar $uri.") + } + } + + /** + * Merge a sequence of comma-separated file lists, some of which may be null to indicate + * no files, into a single comma-separated string. + */ + def mergeFileLists(lists: String*): String = { + val merged = lists.filterNot(StringUtils.isBlank) + .flatMap(Utils.stringToSeq) + if (merged.nonEmpty) merged.mkString(",") else null + } + + private def splitOnFragment(path: String): (URI, Option[String]) = { + val uri = Utils.resolveURI(path) + val withoutFragment = new URI(uri.getScheme, uri.getSchemeSpecificPart, null) + (withoutFragment, Option(uri.getFragment)) + } + + private def resolveGlobPath(uri: URI, hadoopConf: Configuration): Array[String] = { + uri.getScheme match { + case "local" | "http" | "https" | "ftp" => Array(uri.toString) + case _ => + val fs = FileSystem.get(uri, hadoopConf) + Option(fs.globStatus(new Path(uri))).map { status => + status.filter(_.isFile).map(_.getPath.toUri.toString) + }.getOrElse(Array(uri.toString)) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e14f9845e6db..8353e64a619c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -40,6 +40,7 @@ import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdenti import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -111,7 +112,9 @@ class SparkHadoopUtil extends Logging { * subsystems. 
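splitOnFragment() above lets a "#alias" suffix rename a downloaded dependency, and a glob that expands to several files under a single alias is now rejected as ambiguous. A worked example of the fragment split (the URI is illustrative):

    import java.net.URI

    val uri = new URI("hdfs:///libs/dep-1.0.jar#dep.jar")
    val withoutFragment = new URI(uri.getScheme, uri.getSchemeSpecificPart, null)
    val fragment = Option(uri.getFragment)
    // withoutFragment: hdfs:///libs/dep-1.0.jar (globbed against the file system)
    // fragment:        Some("dep.jar")          (re-attached to every resolved path)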
*/ def newConfiguration(conf: SparkConf): Configuration = { - SparkHadoopUtil.newConfiguration(conf) + val hadoopConf = SparkHadoopUtil.newConfiguration(conf) + hadoopConf.addResource(SparkHadoopUtil.SPARK_HADOOP_CONF_FILE) + hadoopConf } /** @@ -144,7 +147,8 @@ class SparkHadoopUtil extends Logging { private[spark] def addDelegationTokens(tokens: Array[Byte], sparkConf: SparkConf) { UserGroupInformation.setConfiguration(newConfiguration(sparkConf)) val creds = deserialize(tokens) - logInfo(s"Adding/updating delegation tokens ${dumpTokens(creds)}") + logInfo("Updating delegation tokens for current user.") + logDebug(s"Adding/updating delegation tokens ${dumpTokens(creds)}") addCurrentUserCredentials(creds) } @@ -319,19 +323,6 @@ class SparkHadoopUtil extends Logging { } } - /** - * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. - * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. - */ - private[spark] def getConfBypassingFSCache( - hadoopConf: Configuration, - scheme: String): Configuration = { - val newConf = new Configuration(hadoopConf) - val confKey = s"fs.${scheme}.impl.disable.cache" - newConf.setBoolean(confKey, true) - newConf - } - /** * Dump the credentials' tokens to string values. * @@ -435,19 +426,27 @@ object SparkHadoopUtil { */ private[spark] val UPDATE_INPUT_METRICS_INTERVAL_RECORDS = 1000 + /** + * Name of the file containing the gateway's Hadoop configuration, to be overlayed on top of the + * cluster's Hadoop config. It is up to the Spark code launching the application to create + * this file if it's desired. If the file doesn't exist, it will just be ignored. + */ + private[spark] val SPARK_HADOOP_CONF_FILE = "__spark_hadoop_conf__.xml" + def get: SparkHadoopUtil = instance /** - * Given an expiration date (e.g. for Hadoop Delegation Tokens) return a the date - * when a given fraction of the duration until the expiration date has passed. - * Formula: current time + (fraction * (time until expiration)) + * Given an expiration date for the current set of credentials, calculate the time when new + * credentials should be created. + * * @param expirationDate Drop-dead expiration date - * @param fraction fraction of the time until expiration return - * @return Date when the fraction of the time until expiration has passed + * @param conf Spark configuration + * @return Timestamp when new credentials should be created. 
*/ - private[spark] def getDateOfNextUpdate(expirationDate: Long, fraction: Double): Long = { + private[spark] def nextCredentialRenewalTime(expirationDate: Long, conf: SparkConf): Long = { val ct = System.currentTimeMillis - (ct + (fraction * (expirationDate - ct))).toLong + val ratio = conf.get(CREDENTIALS_RENEWAL_INTERVAL_RATIO) + (ct + (ratio * (expirationDate - ct))).toLong } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index cbe1f2c3e08a..427c797755b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -58,7 +58,7 @@ import org.apache.spark.util._ */ private[deploy] object SparkSubmitAction extends Enumeration { type SparkSubmitAction = Value - val SUBMIT, KILL, REQUEST_STATUS = Value + val SUBMIT, KILL, REQUEST_STATUS, PRINT_VERSION = Value } /** @@ -67,78 +67,32 @@ private[deploy] object SparkSubmitAction extends Enumeration { * This program handles setting up the classpath with relevant Spark dependencies and provides * a layer over the different cluster managers and deploy modes that Spark supports. */ -object SparkSubmit extends CommandLineUtils with Logging { +private[spark] class SparkSubmit extends Logging { import DependencyUtils._ + import SparkSubmit._ - // Cluster managers - private val YARN = 1 - private val STANDALONE = 2 - private val MESOS = 4 - private val LOCAL = 8 - private val KUBERNETES = 16 - private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES - - // Deploy modes - private val CLIENT = 1 - private val CLUSTER = 2 - private val ALL_DEPLOY_MODES = CLIENT | CLUSTER - - // Special primary resource names that represent shells rather than application jars. - private val SPARK_SHELL = "spark-shell" - private val PYSPARK_SHELL = "pyspark-shell" - private val SPARKR_SHELL = "sparkr-shell" - private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" - private val R_PACKAGE_ARCHIVE = "rpkg.zip" - - private val CLASS_NOT_FOUND_EXIT_STATUS = 101 - - // Following constants are visible for testing. - private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = - "org.apache.spark.deploy.yarn.YarnClusterApplication" - private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() - private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() - private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = - "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" - - // scalastyle:off println - private[spark] def printVersionAndExit(): Unit = { - printStream.println("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ - """.format(SPARK_VERSION)) - printStream.println("Using Scala %s, %s, %s".format( - Properties.versionString, Properties.javaVmName, Properties.javaVersion)) - printStream.println("Branch %s".format(SPARK_BRANCH)) - printStream.println("Compiled by user %s on %s".format(SPARK_BUILD_USER, SPARK_BUILD_DATE)) - printStream.println("Revision %s".format(SPARK_REVISION)) - printStream.println("Url %s".format(SPARK_REPO_URL)) - printStream.println("Type --help for more information.") - exitFn(0) - } - // scalastyle:on println - - override def main(args: Array[String]): Unit = { + def doSubmit(args: Array[String]): Unit = { // Initialize logging if it hasn't been done yet. Keep track of whether logging needs to // be reset before the application starts. 
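nextCredentialRenewalTime() above replaces the old fraction parameter with a configured renewal ratio. A worked example of the formula (the 0.75 ratio below is an assumed value, not read from any real configuration):

    def nextRenewalTime(nowMs: Long, expirationMs: Long, ratio: Double): Long =
      (nowMs + ratio * (expirationMs - nowMs)).toLong

    // Tokens valid for 24 hours with a 0.75 ratio are renewed after 18 hours:
    val renewAt = nextRenewalTime(0L, 24L * 3600 * 1000, 0.75)   // 64800000 ms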
val uninitLog = initializeLogIfNecessary(true, silent = true) - val appArgs = new SparkSubmitArguments(args) + val appArgs = parseArguments(args) if (appArgs.verbose) { - // scalastyle:off println - printStream.println(appArgs) - // scalastyle:on println + logInfo(appArgs.toString) } appArgs.action match { case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog) case SparkSubmitAction.KILL => kill(appArgs) case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) + case SparkSubmitAction.PRINT_VERSION => printVersion() } } + protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) + } + /** * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. */ @@ -156,6 +110,24 @@ object SparkSubmit extends CommandLineUtils with Logging { .requestSubmissionStatus(args.submissionToRequestStatusFor) } + /** Print version information to the log. */ + private def printVersion(): Unit = { + logInfo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + logInfo("Using Scala %s, %s, %s".format( + Properties.versionString, Properties.javaVmName, Properties.javaVersion)) + logInfo(s"Branch $SPARK_BRANCH") + logInfo(s"Compiled by user $SPARK_BUILD_USER on $SPARK_BUILD_DATE") + logInfo(s"Revision $SPARK_REVISION") + logInfo(s"Url $SPARK_REPO_URL") + logInfo("Type --help for more information.") + } + /** * Submit the application using the provided parameters. * @@ -185,10 +157,7 @@ object SparkSubmit extends CommandLineUtils with Logging { // makes the message printed to the output by the JVM not very helpful. Instead, // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { - // scalastyle:off println - printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") - // scalastyle:on println - exitFn(1) + error(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") } else { throw e } @@ -210,14 +179,11 @@ object SparkSubmit extends CommandLineUtils with Logging { // to use the legacy gateway if the master endpoint turns out to be not a REST server. if (args.isStandaloneCluster && args.useRest) { try { - // scalastyle:off println - printStream.println("Running Spark using the REST application submission protocol.") - // scalastyle:on println - doRunMain() + logInfo("Running Spark using the REST application submission protocol.") } catch { // Fail over to use the legacy submission gateway case e: SubmitRestConnectionException => - printWarning(s"Master endpoint ${args.master} was not a REST server. " + + logWarning(s"Master endpoint ${args.master} was not a REST server. " + "Falling back to legacy submission gateway instead.") args.useRest = false submit(args, false) @@ -255,7 +221,7 @@ object SparkSubmit extends CommandLineUtils with Logging { val clusterManager: Int = args.master match { case "yarn" => YARN case "yarn-client" | "yarn-cluster" => - printWarning(s"Master ${args.master} is deprecated since 2.0." + + logWarning(s"Master ${args.master} is deprecated since 2.0." 
+ " Please use master \"yarn\" with specified deploy mode instead.") YARN case m if m.startsWith("spark") => STANDALONE @@ -263,7 +229,7 @@ object SparkSubmit extends CommandLineUtils with Logging { case m if m.startsWith("k8s") => KUBERNETES case m if m.startsWith("local") => LOCAL case _ => - printErrorAndExit("Master must either be yarn or start with spark, mesos, k8s, or local") + error("Master must either be yarn or start with spark, mesos, k8s, or local") -1 } @@ -271,7 +237,9 @@ object SparkSubmit extends CommandLineUtils with Logging { var deployMode: Int = args.deployMode match { case "client" | null => CLIENT case "cluster" => CLUSTER - case _ => printErrorAndExit("Deploy mode must be either client or cluster"); -1 + case _ => + error("Deploy mode must be either client or cluster") + -1 } // Because the deprecated way of specifying "yarn-cluster" and "yarn-client" encapsulate both @@ -283,16 +251,16 @@ object SparkSubmit extends CommandLineUtils with Logging { deployMode = CLUSTER args.master = "yarn" case ("yarn-cluster", "client") => - printErrorAndExit("Client deploy mode is not compatible with master \"yarn-cluster\"") + error("Client deploy mode is not compatible with master \"yarn-cluster\"") case ("yarn-client", "cluster") => - printErrorAndExit("Cluster deploy mode is not compatible with master \"yarn-client\"") + error("Cluster deploy mode is not compatible with master \"yarn-client\"") case (_, mode) => args.master = "yarn" } // Make sure YARN is included in our build if we're trying to use it if (!Utils.classIsLoadable(YARN_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { - printErrorAndExit( + error( "Could not load YARN classes. " + "This copy of Spark may not have been compiled with YARN support.") } @@ -302,7 +270,7 @@ object SparkSubmit extends CommandLineUtils with Logging { args.master = Utils.checkAndGetK8sMasterUrl(args.master) // Make sure KUBERNETES is included in our build if we're trying to use it if (!Utils.classIsLoadable(KUBERNETES_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { - printErrorAndExit( + error( "Could not load KUBERNETES classes. 
" + "This copy of Spark may not have been compiled with KUBERNETES support.") } @@ -311,25 +279,23 @@ object SparkSubmit extends CommandLineUtils with Logging { // Fail fast, the following modes are not supported or applicable (clusterManager, deployMode) match { case (STANDALONE, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for python " + + error("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") case (STANDALONE, CLUSTER) if args.isR => - printErrorAndExit("Cluster deploy mode is currently not supported for R " + + error("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") case (KUBERNETES, _) if args.isPython => - printErrorAndExit("Python applications are currently not supported for Kubernetes.") + error("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => - printErrorAndExit("R applications are currently not supported for Kubernetes.") - case (KUBERNETES, CLIENT) => - printErrorAndExit("Client mode is currently not supported for Kubernetes.") + error("R applications are currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => - printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") + error("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") + error("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.") + error("Cluster deploy mode is not applicable to Spark SQL shell.") case (_, CLUSTER) if isThriftServer(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.") + error("Cluster deploy mode is not applicable to Spark Thrift server.") case _ => } @@ -348,7 +314,8 @@ object SparkSubmit extends CommandLineUtils with Logging { // Resolve maven dependencies if there are any and add classpath to jars. 
Add them to py-files // too for packages that include Python code val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies( - args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath) + args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath, + args.ivySettingsPath) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) @@ -481,11 +448,11 @@ object SparkSubmit extends CommandLineUtils with Logging { if (args.isR && clusterManager == YARN) { val sparkRPackagePath = RUtils.localSparkRPackagePath if (sparkRPackagePath.isEmpty) { - printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") + error("SPARK_HOME does not exist for R application in YARN mode.") } val sparkRPackageFile = new File(sparkRPackagePath.get, SPARKR_PACKAGE_ARCHIVE) if (!sparkRPackageFile.exists()) { - printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") + error(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") } val sparkRPackageURI = Utils.resolveURI(sparkRPackageFile.getAbsolutePath).toString @@ -498,7 +465,7 @@ object SparkSubmit extends CommandLineUtils with Logging { val rPackageFile = RPackageUtils.zipRLibraries(new File(RUtils.rPackages.get), R_PACKAGE_ARCHIVE) if (!rPackageFile.exists()) { - printErrorAndExit("Failed to zip all the built R packages.") + error("Failed to zip all the built R packages.") } val rPackageURI = Utils.resolveURI(rPackageFile.getAbsolutePath).toString @@ -509,12 +476,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // TODO: Support distributing R packages with standalone cluster if (args.isR && clusterManager == STANDALONE && !RUtils.rPackages.isEmpty) { - printErrorAndExit("Distributing R packages with standalone cluster is not supported.") + error("Distributing R packages with standalone cluster is not supported.") } // TODO: Support distributing R packages with mesos cluster if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) { - printErrorAndExit("Distributing R packages with mesos cluster is not supported.") + error("Distributing R packages with mesos cluster is not supported.") } // If we're running an R app, set the main class to our specific R runner @@ -584,10 +551,11 @@ object SparkSubmit extends CommandLineUtils with Logging { confKey = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.cores.max"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.files"), OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, confKey = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, @@ -786,9 +754,7 @@ object SparkSubmit extends CommandLineUtils with Logging { private def setRMPrincipal(sparkConf: SparkConf): Unit = { val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" - // scalastyle:off println - printStream.println(s"Setting 
${key} to ${shortUserName}") - // scalastyle:off println + logInfo(s"Setting ${key} to ${shortUserName}") sparkConf.set(key, shortUserName) } @@ -804,16 +770,14 @@ object SparkSubmit extends CommandLineUtils with Logging { sparkConf: SparkConf, childMainClass: String, verbose: Boolean): Unit = { - // scalastyle:off println if (verbose) { - printStream.println(s"Main class:\n$childMainClass") - printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") + logInfo(s"Main class:\n$childMainClass") + logInfo(s"Arguments:\n${childArgs.mkString("\n")}") // sysProps may contain sensitive information, so redact before printing - printStream.println(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") - printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") - printStream.println("\n") + logInfo(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") + logInfo(s"Classpath elements:\n${childClasspath.mkString("\n")}") + logInfo("\n") } - // scalastyle:on println val loader = if (sparkConf.get(DRIVER_USER_CLASS_PATH_FIRST)) { @@ -835,23 +799,19 @@ object SparkSubmit extends CommandLineUtils with Logging { mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => - e.printStackTrace(printStream) + logWarning(s"Failed to load $childMainClass.", e) if (childMainClass.contains("thriftserver")) { - // scalastyle:off println - printStream.println(s"Failed to load main class $childMainClass.") - printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") - // scalastyle:on println + logInfo(s"Failed to load main class $childMainClass.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") } - System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) case e: NoClassDefFoundError => - e.printStackTrace(printStream) + logWarning(s"Failed to load $childMainClass: ${e.getMessage()}") if (e.getMessage.contains("org/apache/hadoop/hive")) { - // scalastyle:off println - printStream.println(s"Failed to load hive class.") - printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") - // scalastyle:on println + logInfo(s"Failed to load hive class.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") } - System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) } val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { @@ -859,7 +819,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } else { // SPARK-4170 if (classOf[scala.App].isAssignableFrom(mainClass)) { - printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") + logWarning("Subclasses of scala.App may not work correctly. 
Use a main() method instead.") } new JavaMainApplication(mainClass) } @@ -878,29 +838,90 @@ object SparkSubmit extends CommandLineUtils with Logging { app.start(childArgs.toArray, sparkConf) } catch { case t: Throwable => - findCause(t) match { - case SparkUserAppException(exitCode) => - System.exit(exitCode) - - case t: Throwable => - throw t - } + throw findCause(t) } } - private[deploy] def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { - val uri = Utils.resolveURI(localJar) - uri.getScheme match { - case "file" | "local" => - val file = new File(uri.getPath) - if (file.exists()) { - loader.addURL(file.toURI.toURL) - } else { - printWarning(s"Local jar $file does not exist, skipping.") + /** Throw a SparkException with the given error message. */ + private def error(msg: String): Unit = throw new SparkException(msg) + +} + + +/** + * This entry point is used by the launcher library to start in-process Spark applications. + */ +private[spark] object InProcessSparkSubmit { + + def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() + submit.doSubmit(args) + } + +} + +object SparkSubmit extends CommandLineUtils with Logging { + + // Cluster managers + private val YARN = 1 + private val STANDALONE = 2 + private val MESOS = 4 + private val LOCAL = 8 + private val KUBERNETES = 16 + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES + + // Deploy modes + private val CLIENT = 1 + private val CLUSTER = 2 + private val ALL_DEPLOY_MODES = CLIENT | CLUSTER + + // Special primary resource names that represent shells rather than application jars. + private val SPARK_SHELL = "spark-shell" + private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" + private val R_PACKAGE_ARCHIVE = "rpkg.zip" + + private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + + // Following constants are visible for testing. + private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.yarn.YarnClusterApplication" + private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() + private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() + private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" + + override def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() { + self => + + override protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) { + override protected def logInfo(msg: => String): Unit = self.logInfo(msg) + + override protected def logWarning(msg: => String): Unit = self.logWarning(msg) } - case _ => - printWarning(s"Skip remote jar $uri.") + } + + override protected def logInfo(msg: => String): Unit = printMessage(msg) + + override protected def logWarning(msg: => String): Unit = printMessage(s"Warning: $msg") + + override def doSubmit(args: Array[String]): Unit = { + try { + super.doSubmit(args) + } catch { + case e: SparkUserAppException => + exitFn(e.exitCode) + case e: SparkException => + printErrorAndExit(e.getMessage()) + } + } + } + + submit.doSubmit(args) } /** @@ -949,17 +970,6 @@ object SparkSubmit extends CommandLineUtils with Logging { res == SparkLauncher.NO_RESOURCE } - /** - * Merge a sequence of comma-separated file lists, some of which may be null to indicate - * no files, into a single comma-separated string. 
- */ - private[deploy] def mergeFileLists(lists: String*): String = { - val merged = lists.filterNot(StringUtils.isBlank) - .flatMap(_.split(",")) - .mkString(",") - if (merged == "") null else merged - } - } /** Provides utility functions to be used inside SparkSubmit. */ @@ -987,12 +997,12 @@ private[spark] object SparkSubmitUtils { override def toString: String = s"$groupId:$artifactId:$version" } -/** - * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided - * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. - * @param coordinates Comma-delimited string of maven coordinates - * @return Sequence of Maven coordinates - */ + /** + * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided + * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. + * @param coordinates Comma-delimited string of maven coordinates + * @return Sequence of Maven coordinates + */ def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = { coordinates.split(",").map { p => val splits = p.replace("/", ":").split(":") @@ -1291,6 +1301,13 @@ private[spark] object SparkSubmitUtils { rule } + def parseSparkConfProperty(pair: String): (String, String) = { + pair.split("=", 2).toSeq match { + case Seq(k, v) => (k, v) + case _ => throw new SparkException(s"Spark config without '=': $pair") + } + } + } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 9db7a1fe3106..0733fdb72caf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import java.io.{ByteArrayOutputStream, PrintStream} +import java.io.{ByteArrayOutputStream, File, PrintStream} import java.lang.reflect.InvocationTargetException import java.net.URI import java.nio.charset.StandardCharsets @@ -29,7 +29,9 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.io.Source import scala.util.Try +import org.apache.spark.{SparkException, SparkUserAppException} import org.apache.spark.deploy.SparkSubmitAction._ +import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils @@ -40,7 +42,7 @@ import org.apache.spark.util.Utils * The env argument is used for testing. */ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) - extends SparkSubmitArgumentsParser { + extends SparkSubmitArgumentsParser with Logging { var master: String = null var deployMode: String = null var executorMemory: String = null @@ -63,6 +65,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var packages: String = null var repositories: String = null var ivyRepoPath: String = null + var ivySettingsPath: Option[String] = None var packagesExclusions: String = null var verbose: Boolean = false var isPython: Boolean = false @@ -84,8 +87,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S /** Default properties present in the currently defined defaults file. 
*/ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() - // scalastyle:off println - if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") + if (verbose) { + logInfo(s"Using properties file: $propertiesFile") + } Option(propertiesFile).foreach { filename => val properties = Utils.getPropertiesFromFile(filename) properties.foreach { case (k, v) => @@ -94,21 +98,16 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Property files may contain sensitive information, so redact before printing if (verbose) { Utils.redact(properties).foreach { case (k, v) => - SparkSubmit.printStream.println(s"Adding default property: $k=$v") + logInfo(s"Adding default property: $k=$v") } } } - // scalastyle:on println defaultProperties } // Set parameters from command line arguments - try { - parse(args.asJava) - } catch { - case e: IllegalArgumentException => - SparkSubmit.printErrorAndExit(e.getMessage()) - } + parse(args.asJava) + // Populate `sparkProperties` map from properties file mergeDefaultSparkProperties() // Remove keys that don't start with "spark." from `sparkProperties`. @@ -140,7 +139,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S sparkProperties.foreach { case (k, v) => if (!k.startsWith("spark.")) { sparkProperties -= k - SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") + logWarning(s"Ignoring non-spark config property: $k=$v") } } } @@ -184,6 +183,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull files = Option(files).orElse(sparkProperties.get("spark.files")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull + ivySettingsPath = sparkProperties.get("spark.jars.ivySettings") packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull @@ -213,10 +213,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } } catch { case _: Exception => - SparkSubmit.printErrorAndExit(s"Cannot load main class from JAR $primaryResource") + error(s"Cannot load main class from JAR $primaryResource") } case _ => - SparkSubmit.printErrorAndExit( + error( s"Cannot load main class from JAR $primaryResource with URI $uriScheme. 
" + "Please specify a class through --class.") } @@ -233,7 +233,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Set name from main class if not given name = Option(name).orElse(Option(mainClass)).orNull if (name == null && primaryResource != null) { - name = Utils.stripDirectory(primaryResource) + name = new File(primaryResource).getName() } // Action should be SUBMIT unless otherwise specified @@ -246,6 +246,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case SUBMIT => validateSubmitArguments() case KILL => validateKillArguments() case REQUEST_STATUS => validateStatusRequestArguments() + case PRINT_VERSION => } } @@ -254,62 +255,61 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") + error("Must specify a primary resource (JAR or Python or R file)") } if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { - SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") + error("No main class set in JAR; please specify one with --class") } if (driverMemory != null && Try(JavaUtils.byteStringAsBytes(driverMemory)).getOrElse(-1L) <= 0) { - SparkSubmit.printErrorAndExit("Driver Memory must be a positive number") + error("Driver memory must be a positive number") } if (executorMemory != null && Try(JavaUtils.byteStringAsBytes(executorMemory)).getOrElse(-1L) <= 0) { - SparkSubmit.printErrorAndExit("Executor Memory cores must be a positive number") + error("Executor memory must be a positive number") } if (executorCores != null && Try(executorCores.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Executor cores must be a positive number") + error("Executor cores must be a positive number") } if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Total executor cores must be a positive number") + error("Total executor cores must be a positive number") } if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Number of executors must be a positive number") + error("Number of executors must be a positive number") } if (pyFiles != null && !isPython) { - SparkSubmit.printErrorAndExit("--py-files given but primary resource is not a Python script") + error("--py-files given but primary resource is not a Python script") } if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") if (!hasHadoopEnv && !Utils.isTesting) { - throw new Exception(s"When running with master '$master' " + + error(s"When running with master '$master' " + "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.") } } if (proxyUser != null && principal != null) { - SparkSubmit.printErrorAndExit("Only one of --proxy-user or --principal can be provided.") + error("Only one of --proxy-user or --principal can be provided.") } } private def validateKillArguments(): Unit = { if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { - SparkSubmit.printErrorAndExit( - "Killing submissions is only supported in standalone or Mesos mode!") + error("Killing submissions is only supported in standalone or Mesos mode!") } if (submissionToKill == null) { - SparkSubmit.printErrorAndExit("Please specify a submission to kill.") + 
error("Please specify a submission to kill.") } } private def validateStatusRequestArguments(): Unit = { if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { - SparkSubmit.printErrorAndExit( + error( "Requesting submission statuses is only supported in standalone or Mesos mode!") } if (submissionToRequestStatusFor == null) { - SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") + error("Please specify a submission to request status for.") } } @@ -366,7 +366,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case DEPLOY_MODE => if (value != "client" && value != "cluster") { - SparkSubmit.printErrorAndExit("--deploy-mode must be either \"client\" or \"cluster\"") + error("--deploy-mode must be either \"client\" or \"cluster\"") } deployMode = value @@ -403,14 +403,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case KILL_SUBMISSION => submissionToKill = value if (action != null) { - SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.") + error(s"Action cannot be both $action and $KILL.") } action = KILL case STATUS => submissionToRequestStatusFor = value if (action != null) { - SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.") + error(s"Action cannot be both $action and $REQUEST_STATUS.") } action = REQUEST_STATUS @@ -442,7 +442,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S repositories = value case CONF => - val (confName, confValue) = SparkSubmit.parseSparkConfProperty(value) + val (confName, confValue) = SparkSubmitUtils.parseSparkConfProperty(value) sparkProperties(confName) = confValue case PROXY_USER => @@ -461,15 +461,15 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S verbose = true case VERSION => - SparkSubmit.printVersionAndExit() + action = SparkSubmitAction.PRINT_VERSION case USAGE_ERROR => printUsageAndExit(1) case _ => - throw new IllegalArgumentException(s"Unexpected argument '$opt'.") + error(s"Unexpected argument '$opt'.") } - true + action != SparkSubmitAction.PRINT_VERSION } /** @@ -480,7 +480,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S */ override protected def handleUnknown(opt: String): Boolean = { if (opt.startsWith("-")) { - SparkSubmit.printErrorAndExit(s"Unrecognized option '$opt'.") + error(s"Unrecognized option '$opt'.") } primaryResource = @@ -499,20 +499,18 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { - // scalastyle:off println - val outStream = SparkSubmit.printStream if (unknownParam != null) { - outStream.println("Unknown/unsupported param " + unknownParam) + logInfo("Unknown/unsupported param " + unknownParam) } val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] |Usage: spark-submit --status [submission ID] --master [spark://...] 
|Usage: spark-submit run-example [options] example-class [example args]""".stripMargin) - outStream.println(command) + logInfo(command) val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB - outStream.println( + logInfo( s""" |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, @@ -594,12 +592,11 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S ) if (SparkSubmit.isSqlShell(mainClass)) { - outStream.println("CLI options:") - outStream.println(getSqlShellOptions()) + logInfo("CLI options:") + logInfo(getSqlShellOptions()) } - // scalastyle:on println - SparkSubmit.exitFn(exitCode) + throw new SparkUserAppException(exitCode) } /** @@ -653,4 +650,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S System.setErr(currentErr) } } + + private def error(msg: String): Unit = throw new SparkException(msg) + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 94c80ebd55e7..bf1eeb0c1bf5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,18 +18,21 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} -import java.util.{Date, ServiceLoader, UUID} -import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} +import java.nio.file.Files +import java.nio.file.attribute.PosixFilePermissions +import java.util.{Date, ServiceLoader} +import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.io.Source import scala.util.Try import scala.xml.Node import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams -import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} +import com.google.common.util.concurrent.MoreExecutors import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem @@ -58,10 +61,10 @@ import org.apache.spark.util.kvstore._ * * == How new and updated attempts are detected == * - * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any - * entries in the log dir whose modification time is greater than the last scan time - * are considered new or updated. These are replayed to create a new attempt info entry - * and update or create a matching application info element in the list of applications. + * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any entries in the + * log dir whose size changed since the last scan time are considered new or updated. These are + * replayed to create a new attempt info entry and update or create a matching application info + * element in the list of applications. * - Updated attempts are also found in [[checkForLogs]] -- if the attempt's log file has grown, the * attempt is replaced by another one with a larger log size. * @@ -116,8 +119,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs // and applications between check task and clean task. 
- private val pool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder() - .setNameFormat("spark-history-task-%d").setDaemon(true).build()) + private val pool = ThreadUtils.newDaemonSingleThreadScheduledExecutor("spark-history-task-%d") // The modification time of the newest log detected during the last scan. Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) @@ -126,11 +128,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) private val storePath = conf.get(LOCAL_STORE_DIR).map(new File(_)) + private val fastInProgressParsing = conf.get(FAST_IN_PROGRESS_PARSING) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => - require(path.isDirectory(), s"Configured store directory ($path) does not exist.") - val dbPath = new File(path, "listing.ldb") + val perms = PosixFilePermissions.fromString("rwx------") + val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath(), + PosixFilePermissions.asFileAttribute(perms)).toFile() + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, AppStatusStore.CURRENT_VERSION, logDir.toString()) @@ -174,7 +179,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Fixed size thread pool to fetch and parse log files. */ private val replayExecutor: ExecutorService = { - if (!conf.contains("spark.testing")) { + if (!Utils.isTesting) { ThreadUtils.newDaemonFixedThreadPool(NUM_PROCESSING_THREADS, "log-replay-executor") } else { MoreExecutors.sameThreadExecutor() @@ -275,7 +280,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { Some(load(appId).toApplicationInfo()) } catch { - case e: NoSuchElementException => + case _: NoSuchElementException => None } } @@ -403,51 +408,81 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) */ private[history] def checkForLogs(): Unit = { try { - val newLastScanTime = getNewLastScanTime() + val newLastScanTime = clock.getTimeMillis() logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") - // scan for modified applications, replay and merge them - val logInfos = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) + + val updated = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) .filter { entry => !entry.isDirectory() && - // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // FsHistoryProvider used to generate a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && - SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) && - recordedFileSize(entry.getPath()) < entry.getLen() + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) + } + .filter { entry => + try { + val info = listing.read(classOf[LogInfo], entry.getPath().toString()) + + if (info.appId.isDefined) { + // If the SHS view has a valid application, update the time the file was last seen so + // that the entry is not deleted from the SHS listing. Also update the file size, in + // case the code below decides we don't need to parse the log. 
+ listing.write(info.copy(lastProcessed = newLastScanTime, fileSize = entry.getLen())) + } + + if (info.fileSize < entry.getLen()) { + if (info.appId.isDefined && fastInProgressParsing) { + // When fast in-progress parsing is on, we don't need to re-parse when the + // size changes, but we do need to invalidate any existing UIs. + invalidateUI(info.appId.get, info.attemptId) + false + } else { + true + } + } else { + false + } + } catch { + case _: NoSuchElementException => + // If the file is currently not being tracked by the SHS, add an entry for it and try + // to parse it. This will allow the cleaner code to detect the file as stale later on + // if it was not possible to parse it. + listing.write(LogInfo(entry.getPath().toString(), newLastScanTime, None, None, + entry.getLen())) + entry.getLen() > 0 + } } .sortWith { case (entry1, entry2) => entry1.getModificationTime() > entry2.getModificationTime() } - if (logInfos.nonEmpty) { - logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") + if (updated.nonEmpty) { + logDebug(s"New/updated attempts found: ${updated.size} ${updated.map(_.getPath)}") } - var tasks = mutable.ListBuffer[Future[_]]() - - try { - for (file <- logInfos) { - tasks += replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(file) + val tasks = updated.map { entry => + try { + replayExecutor.submit(new Runnable { + override def run(): Unit = mergeApplicationListing(entry, newLastScanTime, true) }) + } catch { + // let the iteration over the updated entries break, since an exception on + // replayExecutor.submit (..) indicates the ExecutorService is unable + // to take any more submissions at this time + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + null } - } catch { - // let the iteration over logInfos break, since an exception on - // replayExecutor.submit (..) indicates the ExecutorService is unable - // to take any more submissions at this time - - case e: Exception => - logError(s"Exception while submitting event log for replay", e) - } + }.filter(_ != null) pendingReplayTasksCount.addAndGet(tasks.size) + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. tasks.foreach { task => try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. task.get() } catch { case e: InterruptedException => @@ -459,28 +494,66 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + // Delete all information about applications whose log files disappeared from storage. + // This is done by identifying the event logs which were not touched by the current + // directory scan. + // + // Only entries with valid applications are cleaned up here. Cleaning up invalid log + // files is done by the periodic cleaner task. + val stale = listing.view(classOf[LogInfo]) + .index("lastProcessed") + .last(newLastScanTime - 1) + .asScala + .toList + stale.foreach { log => + log.appId.foreach { appId => + cleanAppData(appId, log.attemptId, log.logPath) + listing.delete(classOf[LogInfo], log.logPath) + } + } + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } } - private def getNewLastScanTime(): Long = { - val fileName = "." 
+ UUID.randomUUID().toString - val path = new Path(logDir, fileName) - val fos = fs.create(path) - + private def cleanAppData(appId: String, attemptId: Option[String], logPath: String): Unit = { try { - fos.close() - fs.getFileStatus(path).getModificationTime - } catch { - case e: Exception => - logError("Exception encountered when attempting to update last scan time", e) - lastScanTime.get() - } finally { - if (!fs.delete(path, true)) { - logWarning(s"Error deleting ${path}") + val app = load(appId) + val (attempt, others) = app.attempts.partition(_.info.attemptId == attemptId) + + assert(attempt.isEmpty || attempt.size == 1) + val isStale = attempt.headOption.exists { a => + if (a.logPath != new Path(logPath).getName()) { + // If the log file name does not match, then probably the old log file was from an + // in progress application. Just return that the app should be left alone. + false + } else { + val maybeUI = synchronized { + activeUIs.remove(appId -> attemptId) + } + + maybeUI.foreach { ui => + ui.invalidate() + ui.ui.store.close() + } + + diskManager.foreach(_.release(appId, attemptId, delete = true)) + true + } + } + + if (isStale) { + if (others.nonEmpty) { + val newAppInfo = new ApplicationInfoWrapper(app.info, others) + listing.write(newAppInfo) + } else { + listing.delete(classOf[ApplicationInfoWrapper], appId) + } } + } catch { + case _: NoSuchElementException => } } @@ -530,7 +603,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the given log file, saving the application in the listing db. */ - protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { + protected def mergeApplicationListing( + fileStatus: FileStatus, + scanTime: Long, + enableOptimizations: Boolean): Unit = { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || eventString.startsWith(APPL_END_EVENT_PREFIX) || @@ -539,104 +615,169 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val logPath = fileStatus.getPath() + val appCompleted = isCompleted(logPath.getName()) + val reparseChunkSize = conf.get(END_EVENT_REPARSE_CHUNK_SIZE) + + // Enable halt support in listener if: + // - app in progress && fast parsing enabled + // - skipping to end event is enabled (regardless of in-progress state) + val shouldHalt = enableOptimizations && + ((!appCompleted && fastInProgressParsing) || reparseChunkSize > 0) + val bus = new ReplayListenerBus() - val listener = new AppListingListener(fileStatus, clock) + val listener = new AppListingListener(fileStatus, clock, shouldHalt) bus.addListener(listener) - replay(fileStatus, bus, eventsFilter = eventsFilter) - - listener.applicationInfo.foreach { app => - // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a - // discussion on the UI lifecycle. - synchronized { - activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => - ui.invalidate() - ui.ui.store.close() - } - } - addListing(app) + logInfo(s"Parsing $logPath for listing data...") + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + bus.replay(in, logPath.toString, !appCompleted, eventsFilter) } - listing.write(new LogInfo(logPath.toString(), fileStatus.getLen())) - } - /** - * Delete event logs from the log directory according to the clean policy defined by the user. 
- */ - private[history] def cleanLogs(): Unit = { - var iterator: Option[KVStoreIterator[ApplicationInfoWrapper]] = None - try { - val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 - - // Iterate descending over all applications whose oldest attempt happened before maxTime. - iterator = Some(listing.view(classOf[ApplicationInfoWrapper]) - .index("oldestAttempt") - .reverse() - .first(maxTime) - .closeableIterator()) - - iterator.get.asScala.foreach { app => - // Applications may have multiple attempts, some of which may not need to be deleted yet. - val (remaining, toDelete) = app.attempts.partition { attempt => - attempt.info.lastUpdated.getTime() >= maxTime + // If enabled above, the listing listener will halt parsing when there's enough information to + // create a listing entry. When the app is completed, or fast parsing is disabled, we still need + // to replay until the end of the log file to try to find the app end event. Instead of reading + // and parsing line by line, this code skips bytes from the underlying stream so that it is + // positioned somewhere close to the end of the log file. + // + // Because the application end event is written while some Spark subsystems such as the + // scheduler are still active, there is no guarantee that the end event will be the last + // in the log. So, to be safe, the code uses a configurable chunk to be re-parsed at + // the end of the file, and retries parsing the whole log later if the needed data is + // still not found. + // + // Note that skipping bytes in compressed files is still not cheap, but there are still some + // minor gains over the normal log parsing done by the replay bus. + // + // This code re-opens the file so that it knows where it's skipping to. This isn't as cheap as + // just skipping from the current position, but there isn't a a good way to detect what the + // current position is, since the replay listener bus buffers data internally. + val lookForEndEvent = shouldHalt && (appCompleted || !fastInProgressParsing) + if (lookForEndEvent && listener.applicationInfo.isDefined) { + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + val target = fileStatus.getLen() - reparseChunkSize + if (target > 0) { + logInfo(s"Looking for end event; skipping $target bytes from $logPath...") + var skipped = 0L + while (skipped < target) { + skipped += in.skip(target - skipped) + } } - if (remaining.nonEmpty) { - val newApp = new ApplicationInfoWrapper(app.info, remaining) - listing.write(newApp) + val source = Source.fromInputStream(in).getLines() + + // Because skipping may leave the stream in the middle of a line, read the next line + // before replaying. + if (target > 0) { + source.next() } - toDelete.foreach { attempt => - val logPath = new Path(logDir, attempt.logPath) + bus.replay(source, logPath.toString, !appCompleted, eventsFilter) + } + } + + logInfo(s"Finished parsing $logPath") + + listener.applicationInfo match { + case Some(app) if !lookForEndEvent || app.attempts.head.info.completed => + // In this case, we either didn't care about the end event, or we found it. So the + // listing data is good. + invalidateUI(app.info.id, app.attempts.head.info.attemptId) + addListing(app) + listing.write(LogInfo(logPath.toString(), scanTime, Some(app.info.id), + app.attempts.head.info.attemptId, fileStatus.getLen())) + + // For a finished log, remove the corresponding "in progress" entry from the listing DB if + // the file is really gone. 
+ if (appCompleted) { + val inProgressLog = logPath.toString() + EventLoggingListener.IN_PROGRESS try { - listing.delete(classOf[LogInfo], logPath.toString()) + // Fetch the entry first to avoid an RPC when it's already removed. + listing.read(classOf[LogInfo], inProgressLog) + if (!fs.isFile(new Path(inProgressLog))) { + listing.delete(classOf[LogInfo], inProgressLog) + } } catch { case _: NoSuchElementException => - logDebug(s"Log info entry for $logPath not found.") - } - try { - fs.delete(logPath, true) - } catch { - case e: AccessControlException => - logInfo(s"No permission to delete ${attempt.logPath}, ignoring.") - case t: IOException => - logError(s"IOException in cleaning ${attempt.logPath}", t) } } - if (remaining.isEmpty) { - listing.delete(app.getClass(), app.id) - } + case Some(_) => + // In this case, the attempt is still not marked as finished but was expected to. This can + // mean the end event is before the configured threshold, so call the method again to + // re-parse the whole log. + logInfo(s"Reparsing $logPath since end event was not found.") + mergeApplicationListing(fileStatus, scanTime, false) + + case _ => + // If the app hasn't written down its app ID to the logs, still record the entry in the + // listing db, with an empty ID. This will make the log eligible for deletion if the app + // does not make progress after the configured max log age. + listing.write(LogInfo(logPath.toString(), scanTime, None, None, fileStatus.getLen())) + } + } + + /** + * Invalidate an existing UI for a given app attempt. See LoadedAppUI for a discussion on the + * UI lifecycle. + */ + private def invalidateUI(appId: String, attemptId: Option[String]): Unit = { + synchronized { + activeUIs.get((appId, attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() } - } catch { - case t: Exception => logError("Exception while cleaning logs", t) - } finally { - iterator.foreach(_.close()) } } /** - * Replays the events in the specified log file on the supplied `ReplayListenerBus`. - * `ReplayEventsFilter` determines what events are replayed. + * Delete event logs from the log directory according to the clean policy defined by the user. */ - private def replay( - eventLog: FileStatus, - bus: ReplayListenerBus, - eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { - val logPath = eventLog.getPath() - val isCompleted = !logPath.getName().endsWith(EventLoggingListener.IN_PROGRESS) - logInfo(s"Replaying log path: $logPath") - // Note that the eventLog may have *increased* in size since when we grabbed the filestatus, - // and when we read the file here. That is OK -- it may result in an unnecessary refresh - // when there is no update, but will not result in missing an update. We *must* prevent - // an error the other way -- if we report a size bigger (ie later) than the file that is - // actually read, we may never refresh the app. FileStatus is guaranteed to be static - // after it's created, so we get a file size that is no bigger than what is actually read. 
- val logInput = EventLoggingListener.openEventLog(logPath, fs) - try { - bus.replay(logInput, logPath.toString, !isCompleted, eventsFilter) - logInfo(s"Finished parsing $logPath") - } finally { - logInput.close() + private[history] def cleanLogs(): Unit = Utils.tryLog { + val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 + + val expired = listing.view(classOf[ApplicationInfoWrapper]) + .index("oldestAttempt") + .reverse() + .first(maxTime) + .asScala + .toList + expired.foreach { app => + // Applications may have multiple attempts, some of which may not need to be deleted yet. + val (remaining, toDelete) = app.attempts.partition { attempt => + attempt.info.lastUpdated.getTime() >= maxTime + } + + if (remaining.nonEmpty) { + val newApp = new ApplicationInfoWrapper(app.info, remaining) + listing.write(newApp) + } + + toDelete.foreach { attempt => + logInfo(s"Deleting expired event log for ${attempt.logPath}") + val logPath = new Path(logDir, attempt.logPath) + listing.delete(classOf[LogInfo], logPath.toString()) + cleanAppData(app.id, attempt.info.attemptId, logPath.toString()) + deleteLog(logPath) + } + + if (remaining.isEmpty) { + listing.delete(app.getClass(), app.id) + } + } + + // Delete log files that don't have a valid application and exceed the configured max age. + val stale = listing.view(classOf[LogInfo]) + .index("lastProcessed") + .reverse() + .first(maxTime) + .asScala + .toList + stale.foreach { log => + if (log.appId.isEmpty) { + logInfo(s"Deleting invalid / corrupt event log ${log.logPath}") + deleteLog(new Path(log.logPath)) + listing.delete(classOf[LogInfo], log.logPath) + } } } @@ -662,8 +803,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } replayBus.addListener(listener) try { - replay(eventLog, replayBus) + val path = eventLog.getPath() + logInfo(s"Parsing $path to re-build UI...") + Utils.tryWithResource(EventLoggingListener.openEventLog(path, fs)) { in => + replayBus.replay(in, path.toString(), maybeTruncated = !isCompleted(path.toString())) + } trackingStore.close(false) + logInfo(s"Finished parsing $path") } catch { case e: Exception => Utils.tryLogNonFatalError { @@ -703,18 +849,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) | application count=$count}""".stripMargin } - /** - * Return the last known size of the given event log, recorded the last time the file - * system scanner detected a change in the file. 
- */ - private def recordedFileSize(log: Path): Long = { - try { - listing.read(classOf[LogInfo], log.toString()).fileSize - } catch { - case _: NoSuchElementException => 0L - } - } - private def load(appId: String): ApplicationInfoWrapper = { listing.read(classOf[ApplicationInfoWrapper], appId) } @@ -773,11 +907,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logInfo(s"Leasing disk manager space for app $appId / ${attempt.info.attemptId}...") val lease = dm.lease(status.getLen(), isCompressed) val newStorePath = try { - val store = KVUtils.open(lease.tmpPath, metadata) - try { + Utils.tryWithResource(KVUtils.open(lease.tmpPath, metadata)) { store => rebuildAppStore(store, status, attempt.info.lastUpdated.getTime()) - } finally { - store.close() } lease.commit(appId, attempt.info.attemptId) } catch { @@ -806,6 +937,21 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) throw new NoSuchElementException(s"Cannot find attempt $attemptId of $appId.")) } + private def deleteLog(log: Path): Unit = { + try { + fs.delete(log, true) + } catch { + case _: AccessControlException => + logInfo(s"No permission to delete $log, ignoring.") + case ioe: IOException => + logError(s"IOException in cleaning $log", ioe) + } + } + + private def isCompleted(name: String): Boolean = { + !name.endsWith(EventLoggingListener.IN_PROGRESS) + } + } private[history] object FsHistoryProvider { @@ -832,8 +978,16 @@ private[history] case class FsHistoryProviderMetadata( uiVersion: Long, logDir: String) +/** + * Tracking info for event logs detected in the configured log directory. Tracks both valid and + * invalid logs (e.g. unparseable logs, recorded as logs with no app ID) so that the cleaner + * can know what log files are safe to delete. 
+ */ private[history] case class LogInfo( @KVIndexParam logPath: String, + @KVIndexParam("lastProcessed") lastProcessed: Long, + appId: Option[String], + attemptId: Option[String], fileSize: Long) private[history] class AttemptInfoWrapper( @@ -862,11 +1016,17 @@ private[history] class ApplicationInfoWrapper( } -private[history] class AppListingListener(log: FileStatus, clock: Clock) extends SparkListener { +private[history] class AppListingListener( + log: FileStatus, + clock: Clock, + haltEnabled: Boolean) extends SparkListener { private val app = new MutableApplicationInfo() private val attempt = new MutableAttemptInfo(log.getPath().getName(), log.getLen()) + private var gotEnvUpdate = false + private var halted = false + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { app.id = event.appId.orNull app.name = event.appName @@ -875,6 +1035,8 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends attempt.startTime = new Date(event.time) attempt.lastUpdated = new Date(clock.getTimeMillis()) attempt.sparkUser = event.sparkUser + + checkProgress() } override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { @@ -885,11 +1047,18 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends } override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { - val allProperties = event.environmentDetails("Spark Properties").toMap - attempt.viewAcls = allProperties.get("spark.ui.view.acls") - attempt.adminAcls = allProperties.get("spark.admin.acls") - attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") - attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + // Only parse the first env update, since any future changes don't have any effect on + // the ACLs set for the UI. + if (!gotEnvUpdate) { + val allProperties = event.environmentDetails("Spark Properties").toMap + attempt.viewAcls = allProperties.get("spark.ui.view.acls") + attempt.adminAcls = allProperties.get("spark.admin.acls") + attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") + attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + + gotEnvUpdate = true + checkProgress() + } } override def onOtherEvent(event: SparkListenerEvent): Unit = event match { @@ -906,6 +1075,17 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends } } + /** + * Throws a halt exception to stop replay if enough data to create the app listing has been + * read. + */ + private def checkProgress(): Unit = { + if (haltEnabled && !halted && app.id != null && gotEnvUpdate) { + halted = true + throw new HaltReplayException() + } + } + private class MutableApplicationInfo { var id: String = null var name: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 5d62a7d8bebb..6fc12d721e6f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,7 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = - + ++ +
    @@ -65,7 +66,6 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") if (allAppsSize > 0) { ++
    ++ - ++ ++ } else if (requestedIncomplete) { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 75484f5c9f30..611fa563a7cd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -28,6 +28,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.history.config.HISTORY_SERVER_UI_PORT import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot} @@ -149,14 +150,18 @@ class HistoryServer( ui: SparkUI, completed: Boolean) { assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") - ui.getHandlers.foreach(attachHandler) - addFilters(ui.getHandlers, conf) + handlers.synchronized { + ui.getHandlers.foreach(attachHandler) + addFilters(ui.getHandlers, conf) + } } /** Detach a reconstructed UI from this server. Only valid after bind(). */ override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs") - ui.getHandlers.foreach(detachHandler) + handlers.synchronized { + ui.getHandlers.foreach(detachHandler) + } provider.onUIDetached(appId, attemptId, ui) } @@ -276,7 +281,7 @@ object HistoryServer extends Logging { .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] - val port = conf.getInt("spark.history.ui.port", 18080) + val port = conf.get(HISTORY_SERVER_UI_PORT) val server = new HistoryServer(conf, provider, securityManager, port) server.bind() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index 22b6d49d8e2a..25ba9edb9e01 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -44,4 +44,24 @@ private[spark] object config { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10g") + val HISTORY_SERVER_UI_PORT = ConfigBuilder("spark.history.ui.port") + .doc("Web UI port to bind Spark History Server") + .intConf + .createWithDefault(18080) + + val FAST_IN_PROGRESS_PARSING = + ConfigBuilder("spark.history.fs.inProgressOptimization.enabled") + .doc("Enable optimized handling of in-progress logs. This option may leave finished " + + "applications that fail to rename their event logs listed as in-progress.") + .booleanConf + .createWithDefault(true) + + val END_EVENT_REPARSE_CHUNK_SIZE = + ConfigBuilder("spark.history.fs.endEventReparseChunkSize") + .doc("How many bytes to parse at the end of log files looking for the end event. " + + "This is used to speed up generation of application listings by skipping unnecessary " + + "parts of event log files. 
It can be disabled by setting this config to 0.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1m") + } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 68e57b7564ad..f699c75085fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -100,12 +100,29 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
    [XML tags stripped in extraction; each heading and table below is re-wrapped in new container markup on the + side]
    Executor Summary ({allExecutors.length}) / {executorsTable}
    { if (removedExecutors.nonEmpty) {
    Removed Executors ({removedExecutors.length}) / removedExecutorsTable
    } }
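The config.scala hunk shortly above introduces typed entries (HISTORY_SERVER_UI_PORT, FAST_IN_PROGRESS_PARSING, END_EVENT_REPARSE_CHUNK_SIZE) so that a key, its type, its default, and its doc live in one place, and call sites such as HistoryServer switch from conf.getInt("spark.history.ui.port", 18080) to conf.get(HISTORY_SERVER_UI_PORT). ConfigBuilder itself is private[spark], so the sketch below mimics only the idea with public SparkConf API; the IntEntry helper and HistoryConfSketch object are invented for illustration and assume spark-core on the classpath.

import org.apache.spark.SparkConf

// Illustrative only: centralize a config key, its type and its default in one value,
// in the spirit of the ConfigEntry instances added in config.scala above.
object HistoryConfSketch {
  final case class IntEntry(key: String, default: Int, doc: String) {
    def from(conf: SparkConf): Int = conf.getInt(key, default)
  }

  val HistoryServerUiPort: IntEntry =
    IntEntry("spark.history.ui.port", 18080, "Web UI port to bind Spark History Server")

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)
    println(HistoryServerUiPort.from(conf)) // 18080, the default
    conf.set(HistoryServerUiPort.key, "28080")
    println(HistoryServerUiPort.from(conf)) // 28080
  }
}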
    diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
    index bc0bf6a1d970..c629937606b5 100644
    --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
    +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
    @@ -128,15 +128,31 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
    [XML tags stripped in extraction; each heading and table below is re-wrapped in new container markup on the + side]
    Workers ({workers.length}) / {workerTable}
    Running Applications ({activeApps.length}) / {activeAppsTable}
    @@ -144,8 +160,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
    {if (hasDrivers) {
    Running Drivers ({activeDrivers.length}) / {activeDriversTable}
    }
    @@ -154,8 +179,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
    Completed Applications ({completedApps.length}) / {completedAppsTable}
    @@ -164,8 +198,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
    if (hasDrivers) {
    Completed Drivers ({completedDrivers.length}) / {completedDriversTable}
    } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index e88195d95f27..3d99d085408c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -94,6 +94,7 @@ private[spark] abstract class RestSubmissionServer( new HttpConnectionFactory()) connector.setHost(host) connector.setPort(startPort) + connector.setReuseAddress(!Utils.isWindows) server.addConnector(connector) val mainHandler = new ServletContextHandler diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 116a686fe148..5151df00476f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -64,9 +64,9 @@ private[spark] class HadoopDelegationTokenManager( } private def getDelegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = { - val providers = List(new HadoopFSDelegationTokenProvider(fileSystems), - new HiveDelegationTokenProvider, - new HBaseDelegationTokenProvider) + val providers = Seq(new HadoopFSDelegationTokenProvider(fileSystems)) ++ + safeCreateProvider(new HiveDelegationTokenProvider) ++ + safeCreateProvider(new HBaseDelegationTokenProvider) // Filter out providers for which spark.security.credentials.{service}.enabled is false. providers @@ -75,6 +75,17 @@ private[spark] class HadoopDelegationTokenManager( .toMap } + private def safeCreateProvider( + createFn: => HadoopDelegationTokenProvider): Option[HadoopDelegationTokenProvider] = { + try { + Some(createFn) + } catch { + case t: Throwable => + logDebug(s"Failed to load built in provider.", t) + None + } + } + def isServiceEnabled(serviceName: String): Boolean = { val key = providerEnabledConfig.format(serviceName) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index ece5ce79c650..7249eb85ac7c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.KEYTAB import org.apache.spark.util.Utils -private[security] class HiveDelegationTokenProvider +private[spark] class HiveDelegationTokenProvider extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hive" @@ -124,9 +124,9 @@ private[security] class HiveDelegationTokenProvider val currentUser = UserGroupInformation.getCurrentUser() val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser) - // For some reason the Scala-generated anonymous class ends up causing an - // UndeclaredThrowableException, even if you annotate the method with @throws. - try { + // For some reason the Scala-generated anonymous class ends up causing an + // UndeclaredThrowableException, even if you annotate the method with @throws. 
+ try { realUser.doAs(new PrivilegedExceptionAction[T]() { override def run(): T = fn }) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index b19c9904d598..8d6a2b80ef5f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -25,7 +25,7 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit} import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEnv -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.util._ /** * Utility object for launching driver programs such that they share fate with the Worker process. @@ -79,16 +79,21 @@ object DriverWrapper extends Logging { val secMgr = new SecurityManager(sparkConf) val hadoopConf = SparkHadoopUtil.newConfiguration(sparkConf) - val Seq(packagesExclusions, packages, repositories, ivyRepoPath) = - Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy") - .map(sys.props.get(_).orNull) + val Seq(packagesExclusions, packages, repositories, ivyRepoPath, ivySettingsPath) = + Seq( + "spark.jars.excludes", + "spark.jars.packages", + "spark.jars.repositories", + "spark.jars.ivy", + "spark.jars.ivySettings" + ).map(sys.props.get(_).orNull) val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(packagesExclusions, - packages, repositories, ivyRepoPath) + packages, repositories, ivyRepoPath, Option(ivySettingsPath)) val jars = { val jarsProp = sys.props.get("spark.jars").orNull if (!StringUtils.isBlank(resolvedMavenCoordinates)) { - SparkSubmit.mergeFileLists(jarsProp, resolvedMavenCoordinates) + DependencyUtils.mergeFileLists(jarsProp, resolvedMavenCoordinates) } else { jarsProp } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index d4d8521cc820..dc6a3076a511 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import com.google.common.io.Files import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef @@ -142,7 +142,11 @@ private[deploy] class ExecutorRunner( private def fetchAndRunExecutor() { try { // Launch the process - val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), + val subsOpts = appDesc.command.javaOpts.map { + Utils.substituteAppNExecIds(_, appId, execId.toString) + } + val subsCommand = appDesc.command.copy(javaOpts = subsOpts) + val builder = CommandUtils.buildProcessBuilder(subsCommand, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 
3962d422f81d..563b84934f26 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -441,7 +441,7 @@ private[deploy] class Worker( // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker // rpcEndpoint. // Copy ids so that it can be used in the cleanup thread. - val appIds = executors.values.map(_.appId).toSet + val appIds = (executors.values.map(_.appId) ++ drivers.values.map(_.driverId)).toSet val cleanupFuture = concurrent.Future { val appDirs = workDir.listFiles() if (appDirs == null) { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index ce84bc4dae32..8b98ae56fc10 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -77,24 +77,60 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
    [XML tags stripped in extraction; each heading and table below is re-wrapped in new container markup on the + side]
    Running Executors ({runningExecutors.size}) / {runningExecutorTable}
    { if (runningDrivers.nonEmpty) {
    Running Drivers ({runningDrivers.size}) / runningDriverTable
    } }
    { if (finishedExecutors.nonEmpty) {
    Finished Executors ({finishedExecutors.size}) / finishedExecutorTable
    } }
    { if (finishedDrivers.nonEmpty) {
    Finished Drivers ({finishedDrivers.size}) / finishedDriverTable
    } }
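The ApplicationPage, MasterPage and WorkerPage hunks above all build their page bodies from scala.xml literals: pre-rendered tables are spliced in with {...} interpolation and section fragments are concatenated with ++ into one node sequence. The sketch below mimics that composition style only; the element and class names are placeholders, not the markup the real pages emit, and it assumes the scala-xml module is on the classpath.

import scala.xml.NodeSeq

// Illustrative only: compose a UI page from XML literal fragments, splicing tables
// in with {...} and joining sections with ++, as the pages above do.
object UiCompositionSketch {
  // Stand-in for a pre-rendered table such as {runningExecutorTable}.
  def fakeTable(rows: Seq[String]): NodeSeq =
    <table class="table">{rows.map(r => <tr><td>{r}</td></tr>)}</table>

  // Wrap a heading and its table in an enclosing container element (names invented).
  def section(title: String, table: NodeSeq): NodeSeq =
    <div class="section">
      <h4>{title}</h4>
      {table}
    </div>

  def main(args: Array[String]): Unit = {
    val running = Seq("exec-1", "exec-2")
    val finished = Seq("exec-0")
    val content: NodeSeq =
      section(s"Running Executors (${running.size})", fakeTable(running)) ++
      { if (finished.nonEmpty) section(s"Finished Executors (${finished.size})", fakeTable(finished))
        else NodeSeq.Empty }
    println(content)
  }
}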
    diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9b62e4b1b715..48d3630abd1f 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -213,13 +213,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { driverConf.set(key, value) } } - if (driverConf.contains("spark.yarn.credentials.file")) { - logInfo("Will periodically update credentials from: " + - driverConf.get("spark.yarn.credentials.file")) - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .getMethod("startCredentialUpdater", classOf[SparkConf]) - .invoke(null, driverConf) - } cfg.hadoopDelegationCreds.foreach { tokens => SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) @@ -234,11 +227,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.rpcEnv.awaitTermination() - if (driverConf.contains("spark.yarn.credentials.file")) { - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .getMethod("stopCredentialUpdater") - .invoke(null) - } } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2c3a8ef74800..c325222b764b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -35,6 +35,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription} @@ -141,8 +142,7 @@ private[spark] class Executor( conf.getSizeAsBytes("spark.task.maxDirectResultSize", 1L << 20), RpcUtils.maxMessageSizeBytes(conf)) - // Limit of bytes for total size of results (default is 1GB) - private val maxResultSize = Utils.getMaxResultSize(conf) + private val maxResultSize = conf.get(MAX_RESULT_SIZE) // Maintains the list of running tasks. 
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] @@ -480,6 +480,19 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { + case t: TaskKilledException => + logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) + + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => + val killReason = task.reasonIfKilled.getOrElse("unknown reason") + logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate( + taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) + case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason if (!t.isInstanceOf[FetchFailedException]) { @@ -494,19 +507,6 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case t: TaskKilledException => - logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - - case _: InterruptedException | NonFatal(_) if - task != null && task.reasonIfKilled.isDefined => - val killReason = task.reasonIfKilled.getOrElse("unknown reason") - logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate( - taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) - case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskCommitDeniedReason setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index b0cd7110a3b4..f27aca03773a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -23,6 +23,7 @@ import java.util.regex.PatternSyntaxException import scala.util.matching.Regex import org.apache.spark.network.util.{ByteUnit, JavaUtils} +import org.apache.spark.util.Utils private object ConfigHelpers { @@ -45,7 +46,7 @@ private object ConfigHelpers { } def stringToSeq[T](str: String, converter: String => T): Seq[T] = { - str.split(",").map(_.trim()).filter(_.nonEmpty).map(converter) + Utils.stringToSeq(str).map(converter) } def seqToString[T](v: Seq[T], stringConverter: T => String): String = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index eb12ddf96131..6bb98c37b447 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -38,10 +38,13 @@ package object config { ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val DRIVER_MEMORY = ConfigBuilder("spark.driver.memory") + .doc("Amount of memory to use for the driver process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") private[spark] val DRIVER_MEMORY_OVERHEAD 
= ConfigBuilder("spark.driver.memoryOverhead") + .doc("The amount of off-heap memory to be allocated per driver in cluster mode, " + + "in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createOptional @@ -62,6 +65,7 @@ package object config { .createWithDefault(false) private[spark] val EVENT_LOG_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.eventLog.buffer.kb") + .doc("Buffer size to use when writing to output streams, in KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) .createWithDefaultString("100k") @@ -81,10 +85,13 @@ package object config { ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val EXECUTOR_MEMORY = ConfigBuilder("spark.executor.memory") + .doc("Amount of memory to use per executor process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.executor.memoryOverhead") + .doc("The amount of off-heap memory to be allocated per executor in cluster mode, " + + "in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createOptional @@ -119,6 +126,10 @@ package object config { private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + private[spark] val DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO = + ConfigBuilder("spark.dynamicAllocation.executorAllocationRatio") + .doubleConf.createWithDefault(1.0) + private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("3s") @@ -294,6 +305,12 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val IGNORE_MISSING_FILES = ConfigBuilder("spark.files.ignoreMissingFiles") + .doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " + + "encountering missing files and the contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) + private[spark] val APP_CALLER_CONTEXT = ConfigBuilder("spark.log.callerContext") .stringConf .createOptional @@ -353,7 +370,7 @@ package object config { private[spark] val BUFFER_WRITE_CHUNK_SIZE = ConfigBuilder("spark.buffer.write.chunkSize") .internal() - .doc("The chunk size during writing out the bytes of ChunkedByteBuffer.") + .doc("The chunk size in bytes during writing out the bytes of ChunkedByteBuffer.") .bytesConf(ByteUnit.BYTE) .checkValue(_ <= Int.MaxValue, "The chunk size during writing out the bytes of" + " ChunkedByteBuffer should not larger than Int.MaxValue.") @@ -368,9 +385,9 @@ package object config { private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD = ConfigBuilder("spark.shuffle.accurateBlockThreshold") - .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " + - "record the size accurately if it's above this config. This helps to prevent OOM by " + - "avoiding underestimating shuffle block size when fetch shuffle blocks.") + .doc("Threshold in bytes above which the size of shuffle blocks in " + + "HighlyCompressedMapStatus is accurately recorded. 
This helps to prevent OOM " + + "by avoiding underestimating shuffle block size when fetch shuffle blocks.") .bytesConf(ByteUnit.BYTE) .createWithDefault(100 * 1024 * 1024) @@ -389,23 +406,23 @@ package object config { private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS = ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress") - .doc("This configuration limits the number of remote blocks being fetched per reduce task" + - " from a given host port. When a large number of blocks are being requested from a given" + - " address in a single fetch or simultaneously, this could crash the serving executor or" + - " Node Manager. This is especially useful to reduce the load on the Node Manager when" + - " external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") + .doc("This configuration limits the number of remote blocks being fetched per reduce task " + + "from a given host port. When a large number of blocks are being requested from a given " + + "address in a single fetch or simultaneously, this could crash the serving executor or " + + "Node Manager. This is especially useful to reduce the load on the Node Manager when " + + "external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") .intConf .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.") .createWithDefault(Int.MaxValue) private[spark] val MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = ConfigBuilder("spark.maxRemoteBlockSizeFetchToMem") - .doc("Remote block will be fetched to disk when size of the block is " + - "above this threshold. This is to avoid a giant request takes too much memory. We can " + - "enable this config by setting a specific value(e.g. 200m). Note this configuration will " + - "affect both shuffle fetch and block manager remote block fetch. For users who " + - "enabled external shuffle service, this feature can only be worked when external shuffle" + - " service is newer than Spark 2.2.") + .doc("Remote block will be fetched to disk when size of the block is above this threshold " + + "in bytes. This is to avoid a giant request takes too much memory. We can enable this " + + "config by setting a specific value(e.g. 200m). Note this configuration will affect " + + "both shuffle fetch and block manager remote block fetch. For users who enabled " + + "external shuffle service, this feature can only be worked when external shuffle" + + "service is newer than Spark 2.2.") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) @@ -419,9 +436,9 @@ package object config { private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") - .doc("Size of the in-memory buffer for each shuffle file output stream. " + - "These buffers reduce the number of disk seeks and system calls made " + - "in creating intermediate shuffle files.") + .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " + + "otherwise specified. 
These buffers reduce the number of disk seeks and system calls " + + "made in creating intermediate shuffle files.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, s"The file buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") @@ -430,7 +447,7 @@ package object config { private[spark] val SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.shuffle.unsafe.file.output.buffer") .doc("The file system for this buffer size after each partition " + - "is written in unsafe shuffle writer.") + "is written in unsafe shuffle writer. In KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, s"The buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") @@ -438,7 +455,7 @@ package object config { private[spark] val SHUFFLE_DISK_WRITE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.spill.diskWriteBufferSize") - .doc("The buffer size to use when writing the sorted records to an on-disk file.") + .doc("The buffer size, in bytes, to use when writing the sorted records to an on-disk file.") .bytesConf(ByteUnit.BYTE) .checkValue(v => v > 0 && v <= Int.MaxValue, s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.") @@ -513,4 +530,21 @@ package object config { .checkValue(v => v > 0, "The threshold should be positive.") .createWithDefault(10000000) + private[spark] val MAX_RESULT_SIZE = ConfigBuilder("spark.driver.maxResultSize") + .doc("Size limit for results.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1g") + + private[spark] val CREDENTIALS_RENEWAL_INTERVAL_RATIO = + ConfigBuilder("spark.security.credentials.renewalRatio") + .doc("Ratio of the credential's expiration time when Spark should fetch new credentials.") + .doubleConf + .createWithDefault(0.75d) + + private[spark] val CREDENTIALS_RENEWAL_RETRY_WAIT = + ConfigBuilder("spark.security.credentials.retryWait") + .doc("How long to wait before retrying to fetch new credentials after a failure.") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("1h") + } diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 50f51e1af453..e6e9c9e32885 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -20,6 +20,7 @@ package org.apache.spark.internal.io import org.apache.hadoop.fs._ import org.apache.hadoop.mapreduce._ +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -28,8 +29,9 @@ import org.apache.spark.util.Utils * * 1. Implementations must be serializable, as the committer instance instantiated on the driver * will be used for tasks on executors. - * 2. Implementations should have a constructor with 2 arguments: - * (jobId: String, path: String) + * 2. Implementations should have a constructor with 2 or 3 arguments: + * (jobId: String, path: String) or + * (jobId: String, path: String, dynamicPartitionOverwrite: Boolean) * 3. A committer should not be reused across multiple Spark jobs. 
* * The proper call sequence is: @@ -131,7 +133,7 @@ abstract class FileCommitProtocol { } -object FileCommitProtocol { +object FileCommitProtocol extends Logging { class TaskCommitMessage(val obj: Any) extends Serializable object EmptyTaskCommitMessage extends TaskCommitMessage(null) @@ -139,10 +141,30 @@ object FileCommitProtocol { /** * Instantiates a FileCommitProtocol using the given className. */ - def instantiate(className: String, jobId: String, outputPath: String) - : FileCommitProtocol = { + def instantiate( + className: String, + jobId: String, + outputPath: String, + dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = { + + logDebug(s"Creating committer $className; job $jobId; output=$outputPath;" + + s" dynamic=$dynamicPartitionOverwrite") val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] - val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) - ctor.newInstance(jobId, outputPath) + // First try the constructor with arguments (jobId: String, outputPath: String, + // dynamicPartitionOverwrite: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + logDebug("Using (String, String, Boolean) constructor") + ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + logDebug("Falling back to (String, String) constructor") + require(!dynamicPartitionOverwrite, + "Dynamic Partition Overwrite is enabled but" + + s" the committer ${className} does not have the appropriate constructor") + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 95c99d29c3a9..3e60c50ada59 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -39,8 +39,19 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * * @param jobId the job's or stage's id * @param path the job's output path, or null if committer acts as a noop + * @param dynamicPartitionOverwrite If true, Spark will overwrite partition directories at runtime + * dynamically, i.e., we first write files under a staging + * directory with partition path, e.g. + * /path/to/staging/a=1/b=1/xxx.parquet. When committing the job, + * we first clean up the corresponding partition directories at + * destination path, e.g. /path/to/destination/a=1/b=1, and move + * files from staging directory to the corresponding partition + * directories under destination path. */ -class HadoopMapReduceCommitProtocol(jobId: String, path: String) +class HadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) extends FileCommitProtocol with Serializable with Logging { import FileCommitProtocol._ @@ -67,9 +78,17 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) @transient private var addedAbsPathFiles: mutable.Map[String, String] = null /** - * The staging directory for all files committed with absolute output paths. + * Tracks partitions with default path that have new files written into them by this task, + * e.g. a=1/b=2. 
Files under these partitions will be saved into staging directory and moved to + * destination directory at the end, if `dynamicPartitionOverwrite` is true. */ - private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + @transient private var partitionPaths: mutable.Set[String] = null + + /** + * The staging directory of this write job. Spark uses it to deal with files with absolute output + * path, or writing data into partitioned directory with dynamicPartitionOverwrite=true. + */ + private def stagingDir = new Path(path, ".spark-staging-" + jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() @@ -85,11 +104,16 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { val filename = getFilename(taskContext, ext) - val stagingDir: String = committer match { + val stagingDir: Path = committer match { + case _ if dynamicPartitionOverwrite => + assert(dir.isDefined, + "The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.") + partitionPaths += dir.get + this.stagingDir // For FileOutputCommitter it has its own staging path called "work path". case f: FileOutputCommitter => - Option(f.getWorkPath).map(_.toString).getOrElse(path) - case _ => path + new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path)) + case _ => new Path(path) } dir.map { d => @@ -106,8 +130,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) // Include a UUID here to prevent file collisions for one task writing to different dirs. // In principle we could include hash(absoluteDir) instead but this is simpler. - val tmpOutputPath = new Path( - absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + val tmpOutputPath = new Path(stagingDir, UUID.randomUUID().toString() + "-" + filename).toString addedAbsPathFiles(tmpOutputPath) = absOutputPath tmpOutputPath @@ -141,23 +164,52 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { committer.commitJob(jobContext) - val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) - .foldLeft(Map[String, String]())(_ ++ _) - logDebug(s"Committing files staged for absolute locations $filesToMove") + if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + val (allAbsPathFiles, allPartitionPaths) = + taskCommits.map(_.obj.asInstanceOf[(Map[String, String], Set[String])]).unzip + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + + val filesToMove = allAbsPathFiles.foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + if (dynamicPartitionOverwrite) { + val absPartitionPaths = filesToMove.values.map(new Path(_).getParent).toSet + logDebug(s"Clean up absolute partition directories for overwriting: $absPartitionPaths") + absPartitionPaths.foreach(fs.delete(_, true)) + } for ((src, dst) <- filesToMove) { fs.rename(new Path(src), new Path(dst)) } - fs.delete(absPathStagingDir, true) + + if (dynamicPartitionOverwrite) { + val partitionPaths = allPartitionPaths.foldLeft(Set[String]())(_ ++ _) + logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") + for (part <- partitionPaths) { + val finalPartPath = new Path(path, part) + if (!fs.delete(finalPartPath, true) && 
!fs.exists(finalPartPath.getParent)) { + // According to the official hadoop FileSystem API spec, delete op should assume + // the destination is no longer present regardless of return value, thus we do not + // need to double check if finalPartPath exists before rename. + // Also in our case, based on the spec, delete returns false only when finalPartPath + // does not exist. When this happens, we need to take action if parent of finalPartPath + // also does not exist(e.g. the scenario described on SPARK-23815), because + // FileSystem API spec on rename op says the rename dest(finalPartPath) must have + // a parent that exists, otherwise we may get unexpected result on the rename. + fs.mkdirs(finalPartPath.getParent) + } + fs.rename(new Path(stagingDir, part), finalPartPath) + } + } + + fs.delete(stagingDir, true) } } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - fs.delete(absPathStagingDir, true) + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(stagingDir, true) } } @@ -165,13 +217,14 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) committer = setupCommitter(taskContext) committer.setupTask(taskContext) addedAbsPathFiles = mutable.Map[String, String]() + partitionPaths = mutable.Set[String]() } override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { val attemptId = taskContext.getTaskAttemptID SparkHadoopMapRedUtil.commitTask( committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - new TaskCommitMessage(addedAbsPathFiles.toMap) + new TaskCommitMessage(addedAbsPathFiles.toMap -> partitionPaths.toSet) } override def abortTask(taskContext: TaskAttemptContext): Unit = { diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala index aaae33ca4e6f..1b049b786023 100644 --- a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala +++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala @@ -67,13 +67,13 @@ private[spark] abstract class LauncherBackend { } def setAppId(appId: String): Unit = { - if (connection != null) { + if (connection != null && isConnected) { connection.send(new SetAppId(appId)) } } def setState(state: SparkAppHandle.State): Unit = { - if (connection != null && lastState != state) { + if (connection != null && isConnected && lastState != state) { connection.send(new SetState(state)) lastState = state } @@ -114,10 +114,10 @@ private[spark] abstract class LauncherBackend { override def close(): Unit = { try { + _isConnected = false super.close() } finally { onDisconnected() - _isConnected = false } } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index c9ed12f4e1bd..13db4985b0b8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -90,7 +90,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // Otherwise, interpolate the number of partitions we need to try, but overestimate it // by 50%. We also cap the estimation in the end. 
if (results.size == 0) { - numPartsToTry = partsScanned * 4 + numPartsToTry = partsScanned * 4L } else { // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 10451a324b0f..94e7d0b38cba 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -266,17 +266,17 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) numCreated += 1 } } - tries = 0 // if we don't have enough partition groups, create duplicates while (numCreated < targetLen) { - val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries) - tries += 1 + // Copy the preferred location from a random input partition. + // This helps in avoiding skew when the input partitions are clustered by preferred location. + val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs( + rnd.nextInt(partitionLocs.partsWithLocs.length)) val pgroup = new PartitionGroup(Some(nxt_replica)) groupArr += pgroup groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup addPartToPGroup(nxt_part, pgroup) numCreated += 1 - if (tries >= partitionLocs.partsWithLocs.length) tries = 0 } } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 2480559a41b7..44895abc7bd4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -28,6 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred._ import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.mapreduce.TaskType +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -134,6 +135,8 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. 
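The takeAsync fix above (partsScanned * 4L, avoiding Int overflow) and the RDD.take rewrite further down in this diff grow the number of partitions scanned per round the same way: quadruple while nothing has been found, otherwise interpolate from the observed hit rate, overestimate by 50%, and cap the growth. The following self-contained sketch reproduces that estimator; the helper name, the factor constant, and the simulated data are made up for illustration.

// Illustrative only: how take()-style actions decide how many more partitions to
// scan in the next round, based on what previous rounds returned.
object TakeEstimatorSketch {
  val scaleUpFactor = 4L // matches the multiplier used by takeAsync above

  def nextPartsToTry(partsScanned: Long, found: Int, stillNeeded: Int): Long = {
    if (partsScanned == 0) {
      1L // first round: look at a single partition
    } else if (found == 0) {
      partsScanned * scaleUpFactor // nothing yet: grow geometrically
    } else {
      // Interpolate from the observed density, overestimate by 50%, and cap the growth.
      val interpolated = math.ceil(1.5 * stillNeeded * partsScanned / found).toLong
      math.min(math.max(interpolated, 1L), partsScanned * scaleUpFactor)
    }
  }

  def main(args: Array[String]): Unit = {
    // Simulate take(1000) over 10000 partitions, pretending each partition yields 5 rows.
    val totalParts = 10000L
    val want = 1000
    var partsScanned = 0L
    var found = 0
    while (found < want && partsScanned < totalParts) {
      val p = math.min(nextPartsToTry(partsScanned, found, want - found), totalParts - partsScanned)
      found += (p * 5).toInt
      partsScanned += p
      println(s"scanned=$partsScanned found=$found")
    }
  }
}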
@@ -197,17 +200,24 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) - val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) - val inputSplits = if (ignoreEmptySplits) { - allInputSplits.filter(_.getLength > 0) - } else { - allInputSplits - } - val array = new Array[Partition](inputSplits.size) - for (i <- 0 until inputSplits.size) { - array(i) = new HadoopPartition(id, i, inputSplits(i)) + try { + val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) + val inputSplits = if (ignoreEmptySplits) { + allInputSplits.filter(_.getLength > 0) + } else { + allInputSplits + } + val array = new Array[Partition](inputSplits.size) + for (i <- 0 until inputSplits.size) { + array(i) = new HadoopPartition(id, i, inputSplits(i)) + } + array + } catch { + case e: InvalidInputException if ignoreMissingFiles => + logWarning(s"${jobConf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" + + s" partitions returned from this path.", e) + Array.empty[Partition] } - array } override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { @@ -256,6 +266,12 @@ class HadoopRDD[K, V]( try { inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.inputSplit}", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) finished = true @@ -276,6 +292,11 @@ class HadoopRDD[K, V]( try { finished = !reader.next(key, value) } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.inputSplit}", e) + finished = true + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) finished = true diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index e4dd1b6a8249..ff66a04859d1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -28,7 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileInputFormat, FileSplit, InvalidInputException} import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} import org.apache.spark._ @@ -90,6 +90,8 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES) + private val ignoreEmptySplits = 
sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) def getConf: Configuration = { @@ -124,17 +126,25 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala - val rawSplits = if (ignoreEmptySplits) { - allRowSplits.filter(_.getLength > 0) - } else { - allRowSplits - } - val result = new Array[Partition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + try { + val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala + val rawSplits = if (ignoreEmptySplits) { + allRowSplits.filter(_.getLength > 0) + } else { + allRowSplits + } + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = + new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } catch { + case e: InvalidInputException if ignoreMissingFiles => + logWarning(s"${_conf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" + + s" partitions returned from this path.", e) + Array.empty[Partition] } - result } override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { @@ -189,6 +199,12 @@ class NewHadoopRDD[K, V]( _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) _reader } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning( s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", @@ -213,6 +229,11 @@ class NewHadoopRDD[K, V]( try { finished = !reader.nextKeyValue } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e) + finished = true + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning( s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 8798dfc92536..0574abdca32a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -150,7 +150,7 @@ abstract class RDD[T: ClassTag]( val id: Int = sc.newRddId() /** A friendly name for this RDD */ - @transient var name: String = null + @transient var name: String = _ /** Assign a name to this RDD */ def setName(_name: String): this.type = { @@ -224,8 +224,8 @@ abstract class RDD[T: ClassTag]( // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed - private var dependencies_ : Seq[Dependency[_]] = null - @transient private var partitions_ : Array[Partition] = null + private var dependencies_ : Seq[Dependency[_]] = _ + @transient private var partitions_ : Array[Partition] = _ /** An Option holding our checkpoint RDD, if we are checkpointed */ private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD) @@ -297,7 +297,7 @@ abstract class 
RDD[T: ClassTag]( private[spark] def getNarrowAncestors: Seq[RDD[_]] = { val ancestors = new mutable.HashSet[RDD[_]] - def visit(rdd: RDD[_]) { + def visit(rdd: RDD[_]): Unit = { val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]]) val narrowParents = narrowDependencies.map(_.rdd) val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains) @@ -414,6 +414,8 @@ abstract class RDD[T: ClassTag]( * * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, * which can avoid performing a shuffle. + * + * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207. */ def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope { coalesce(numPartitions, shuffle = true) @@ -449,7 +451,7 @@ abstract class RDD[T: ClassTag]( if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. */ val distributePartition = (index: Int, items: Iterator[T]) => { - var position = (new Random(hashing.byteswap32(index))).nextInt(numPartitions) + var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions) items.map { t => // Note that the hash code of the key will just be the key itself. The HashPartitioner // will mod it with the number of total partitions. @@ -951,7 +953,7 @@ abstract class RDD[T: ClassTag]( def collectPartition(p: Int): Array[T] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head } - (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) + partitions.indices.iterator.flatMap(i => collectPartition(i)) } /** @@ -1338,6 +1340,7 @@ abstract class RDD[T: ClassTag]( // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1L + val left = num - buf.size if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate @@ -1345,13 +1348,12 @@ abstract class RDD[T: ClassTag]( if (buf.isEmpty) { numPartsToTry = partsScanned * scaleUpFactor } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor) } } - val left = num - buf.size val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) @@ -1677,8 +1679,7 @@ abstract class RDD[T: ClassTag]( // an RDD and its parent in every batch, in which case the parent may never be checkpointed // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847). private val checkpointAllMarkedAncestors = - Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)) - .map(_.toBoolean).getOrElse(false) + Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).exists(_.toBoolean) /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassTag]: RDD[U] = { @@ -1686,7 +1687,7 @@ abstract class RDD[T: ClassTag]( } /** Returns the jth parent RDD: e.g. 
rdd.parent[T](0) is equivalent to rdd.firstParent[T] */ - protected[spark] def parent[U: ClassTag](j: Int) = { + protected[spark] def parent[U: ClassTag](j: Int): RDD[U] = { dependencies(j).rdd.asInstanceOf[RDD[U]] } @@ -1754,7 +1755,7 @@ abstract class RDD[T: ClassTag]( * collected. Subclasses of RDD may override this method for implementing their own cleaning * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. */ - protected def clearDependencies() { + protected def clearDependencies(): Unit = { dependencies_ = null } @@ -1790,7 +1791,7 @@ abstract class RDD[T: ClassTag]( val lastDepStrings = debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true) - (frontDepStrings ++ lastDepStrings) + frontDepStrings ++ lastDepStrings } } // The first RDD in the dependency stack has no parents, so no need for a +- diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index f951591e02a5..a2936d6ad539 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -332,16 +332,14 @@ private[netty] class NettyRpcEnv( val pipe = Pipe.open() val source = new FileDownloadChannel(pipe.source()) - try { + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) val callback = new FileDownloadCallback(pipe.sink(), source, client) client.stream(parsedUri.getPath(), callback) - } catch { - case e: Exception => - pipe.sink().close() - source.close() - throw e - } + })(catchBlock = { + pipe.sink().close() + source.close() + }) source } @@ -370,24 +368,33 @@ private[netty] class NettyRpcEnv( fileDownloadFactory.createClient(host, port) } - private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel { @volatile private var error: Throwable = _ def setError(e: Throwable): Unit = { + // This setError callback is invoked by internal RPC threads in order to propagate remote + // exceptions to application-level threads which are reading from this channel. When an + // RPC error occurs, the RPC system will call setError() and then will close the + // Pipe.SinkChannel corresponding to the other end of the `source` pipe. Closing of the pipe + // sink will cause `source.read()` operations to return EOF, unblocking the application-level + // reading thread. Thus there is no need to actually call `source.close()` here in the + // onError() callback and, in fact, calling it here would be dangerous because the close() + // would be asynchronous with respect to the read() call and could trigger race-conditions + // that lead to data corruption. See the PR for SPARK-22982 for more details on this topic. error = e - source.close() } override def read(dst: ByteBuffer): Int = { Try(source.read(dst)) match { + // See the documentation above in setError(): if an RPC error has occurred then setError() + // will be called to propagate the RPC error and then `source`'s corresponding + // Pipe.SinkChannel will be closed, unblocking this read. 
In that case, we want to propagate + // the remote RPC exception (and not any exceptions triggered by the pipe close, such as + // ChannelClosedException), hence this `error != null` check: + case _ if error != null => throw error case Success(bytesRead) => bytesRead - case Failure(readErr) => - if (error != null) { - throw error - } else { - throw readErr - } + case Failure(readErr) => throw readErr } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index 7e14938acd8e..c1fedd63f6a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -166,7 +166,7 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi val prevLastReportTimestamp = lastReportTimestamp lastReportTimestamp = System.currentTimeMillis() val previous = new java.util.Date(prevLastReportTimestamp) - logWarning(s"Dropped $droppedEvents events from $name since $previous.") + logWarning(s"Dropped $droppedCount events from $name since $previous.") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index cd8e61d6d020..30cf75d43ee0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -152,7 +152,8 @@ private[scheduler] class BlacklistTracker ( case Some(a) => logInfo(s"Killing blacklisted executor id $exec " + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") - a.killExecutors(Seq(exec), true, true) + a.killExecutors(Seq(exec), adjustTargetNumExecutors = false, countFailures = false, + force = true) case None => logWarning(s"Not attempting to kill blacklisted executor id $exec " + s"since allocation client is not defined.") @@ -209,7 +210,7 @@ private[scheduler] class BlacklistTracker ( updateNextExpiryTime() killBlacklistedExecutor(exec) - val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) + val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(host, HashSet[String]()) blacklistedExecsOnNode += exec } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c2498d4808e9..78b6b34b5d2b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -39,7 +39,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -815,7 +815,8 @@ class DAGScheduler( private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) { // Note that there is a chance that this task is launched after the stage is cancelled. // In that case, we wouldn't have the stage anymore in stageIdToStage. 
- val stageAttemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + val stageAttemptId = + stageIdToStage.get(task.stageId).map(_.latestInfo.attemptNumber).getOrElse(-1) listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) } @@ -1015,15 +1016,24 @@ class DAGScheduler( // might modify state of objects referenced in their closures. This is necessary in Hadoop // where the JobConf/Configuration object is not thread-safe. var taskBinary: Broadcast[Array[Byte]] = null + var partitions: Array[Partition] = null try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). - val taskBinaryBytes: Array[Byte] = stage match { - case stage: ShuffleMapStage => - JavaUtils.bufferToArray( - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) - case stage: ResultStage => - JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + var taskBinaryBytes: Array[Byte] = null + // taskBinaryBytes and partitions are both effected by the checkpoint status. We need + // this synchronization in case another concurrent job is checkpointing this RDD, so we get a + // consistent view of both variables. + RDDCheckpointData.synchronized { + taskBinaryBytes = stage match { + case stage: ShuffleMapStage => + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) + case stage: ResultStage => + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + } + + partitions = stage.rdd.partitions } taskBinary = sc.broadcast(taskBinaryBytes) @@ -1048,9 +1058,9 @@ class DAGScheduler( stage.pendingPartitions.clear() partitionsToCompute.map { id => val locs = taskIdToLocations(id) - val part = stage.rdd.partitions(id) + val part = partitions(id) stage.pendingPartitions += id - new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, + new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1058,9 +1068,9 @@ class DAGScheduler( case stage: ResultStage => partitionsToCompute.map { id => val p: Int = stage.partitions(id) - val part = stage.rdd.partitions(p) + val part = partitions(p) val locs = taskIdToLocations(id) - new ResultTask(stage.id, stage.latestInfo.attemptId, + new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1076,23 +1086,22 @@ class DAGScheduler( logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties)) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark // the stage as completed here in case there are no tasks to run markStageAsFinished(stage, None) - val debugString = stage match { + stage match { case stage: ShuffleMapStage => - s"Stage ${stage} is actually done; " + - s"(available: ${stage.isAvailable}," + - s"available outputs: ${stage.numAvailableOutputs}," + - s"partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + 
s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})") + markMapStageJobsAsFinished(stage) case stage : ResultStage => - s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})") } - logDebug(debugString) - submitWaitingChildStages(stage) } } @@ -1245,7 +1254,7 @@ class DAGScheduler( val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) - if (stageIdToStage(task.stageId).latestInfo.attemptId == task.stageAttemptId) { + if (stageIdToStage(task.stageId).latestInfo.attemptNumber == task.stageAttemptId) { // This task was for the currently running attempt of the stage. Since the task // completed successfully from the perspective of the TaskSetManager, mark it as // no longer pending (the TaskSetManager may consider the task complete even @@ -1297,13 +1306,7 @@ class DAGScheduler( shuffleStage.findMissingPartitions().mkString(", ")) submitStage(shuffleStage) } else { - // Mark any map-stage jobs waiting on this stage as finished - if (shuffleStage.mapStageJobs.nonEmpty) { - val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) - for (job <- shuffleStage.mapStageJobs) { - markMapStageJobAsFinished(job, stats) - } - } + markMapStageJobsAsFinished(shuffleStage) submitWaitingChildStages(shuffleStage) } } @@ -1324,10 +1327,10 @@ class DAGScheduler( val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) - if (failedStage.latestInfo.attemptId != task.stageAttemptId) { + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + - s"(attempt ID ${failedStage.latestInfo.attemptId}) running") + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is @@ -1423,6 +1426,16 @@ class DAGScheduler( } } + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } + } + } + /** * Responds to an executor being lost. This is called inside the event loop, so it assumes it can * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index b3a5b1f1e05b..69bc51c1ecf9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,6 +207,15 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } + override def onExecutorBlacklistedForStage( + event: SparkListenerExecutorBlacklistedForStage): Unit = { + logEvent(event, flushLogger = true) + } + + override def onNodeBlacklistedForStage(event: SparkListenerNodeBlacklistedForStage): Unit = { + logEvent(event, flushLogger = true) + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { logEvent(event, flushLogger = true) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 23121402b102..ba6387a8f08a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -62,6 +62,9 @@ private[spark] class LiveListenerBus(conf: SparkConf) { private val queues = new CopyOnWriteArrayList[AsyncEventQueue]() + // Visible for testing. + @volatile private[scheduler] var queuedEvents = new mutable.ListBuffer[SparkListenerEvent]() + /** Add a listener to queue shared by all non-internal listeners. */ def addToSharedQueue(listener: SparkListenerInterface): Unit = { addToQueue(listener, SHARED_QUEUE) @@ -125,13 +128,39 @@ private[spark] class LiveListenerBus(conf: SparkConf) { /** Post an event to all queues. */ def post(event: SparkListenerEvent): Unit = { - if (!stopped.get()) { - metrics.numEventsPosted.inc() - val it = queues.iterator() - while (it.hasNext()) { - it.next().post(event) + if (stopped.get()) { + return + } + + metrics.numEventsPosted.inc() + + // If the event buffer is null, it means the bus has been started and we can avoid + // synchronization and post events directly to the queues. This should be the most + // common case during the life of the bus. + if (queuedEvents == null) { + postToQueues(event) + return + } + + // Otherwise, need to synchronize to check whether the bus is started, to make sure the thread + // calling start() picks up the new event. + synchronized { + if (!started.get()) { + queuedEvents += event + return } } + + // If the bus was already started when the check above was made, just post directly to the + // queues. 
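The post()/start() handshake being introduced here (its remainder appears just below) buffers events that arrive before the queues exist, and then switches to a lock-free fast path once the bus has started. A rough model of that lifecycle, with hypothetical names and a plain queue standing in for the real async listener queues, could be:

    import java.util.concurrent.ConcurrentLinkedQueue
    import java.util.concurrent.atomic.AtomicBoolean
    import scala.collection.mutable

    // Minimal model of "buffer events until start(), then flush and switch to direct posting".
    final class BufferingBus[E] {
      private val started = new AtomicBoolean(false)
      // Becomes null after start(); from then on post() takes the unsynchronized fast path.
      @volatile private var queuedEvents: mutable.ListBuffer[E] = new mutable.ListBuffer[E]()
      private val delivered = new ConcurrentLinkedQueue[E]()  // stand-in for the real queues

      def post(event: E): Unit = {
        if (queuedEvents == null) { deliver(event); return }     // common case once running
        synchronized {
          if (!started.get()) { queuedEvents += event; return }  // not started yet: buffer
        }
        deliver(event)  // started concurrently with the check above: post directly
      }

      def start(): Unit = synchronized {
        started.set(true)
        queuedEvents.foreach(deliver)  // replay everything received before start()
        queuedEvents = null
      }

      private def deliver(event: E): Unit = { delivered.add(event); () }
    }

This is only a sketch of the buffering idea; the real bus also tracks metrics and fans each event out to several AsyncEventQueue instances.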
+ postToQueues(event) + } + + private def postToQueues(event: SparkListenerEvent): Unit = { + val it = queues.iterator() + while (it.hasNext()) { + it.next().post(event) + } } /** @@ -149,7 +178,11 @@ private[spark] class LiveListenerBus(conf: SparkConf) { } this.sparkContext = sc - queues.asScala.foreach(_.start(sc)) + queues.asScala.foreach { q => + q.start(sc) + queuedEvents.foreach(q.post) + } + queuedEvents = null metricsSystem.registerSource(metrics) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index c9cd662f5709..226c23733c87 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -115,6 +115,8 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } } catch { + case e: HaltReplayException => + // Just stop replay. case _: EOFException if maybeTruncated => case ioe: IOException => throw ioe @@ -124,8 +126,17 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } + override protected def isIgnorableException(e: Throwable): Boolean = { + e.isInstanceOf[HaltReplayException] + } + } +/** + * Exception that can be thrown by listeners to halt replay. This is handled by ReplayListenerBus + * only, and will cause errors if thrown when using other bus implementations. + */ +private[spark] class HaltReplayException extends RuntimeException private[spark] object ReplayListenerBus { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 3b677ca9657d..8a112f6a37b9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -120,6 +120,24 @@ case class SparkListenerExecutorBlacklisted( taskFailures: Int) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerExecutorBlacklistedForStage( + time: Long, + executorId: String, + taskFailures: Int, + stageId: Int, + stageAttemptId: Int) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerNodeBlacklistedForStage( + time: Long, + hostId: String, + executorFailures: Int, + stageId: Int, + stageAttemptId: Int) + extends SparkListenerEvent + @DeveloperApi case class SparkListenerExecutorUnblacklisted(time: Long, executorId: String) extends SparkListenerEvent @@ -261,6 +279,17 @@ private[spark] trait SparkListenerInterface { */ def onExecutorBlacklisted(executorBlacklisted: SparkListenerExecutorBlacklisted): Unit + /** + * Called when the driver blacklists an executor for a stage. + */ + def onExecutorBlacklistedForStage( + executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage): Unit + + /** + * Called when the driver blacklists a node for a stage. + */ + def onNodeBlacklistedForStage(nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage): Unit + /** * Called when the driver re-enables a previously blacklisted executor. 
*/ @@ -339,6 +368,12 @@ abstract class SparkListener extends SparkListenerInterface { override def onExecutorBlacklisted( executorBlacklisted: SparkListenerExecutorBlacklisted): Unit = { } + def onExecutorBlacklistedForStage( + executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage): Unit = { } + + def onNodeBlacklistedForStage( + nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage): Unit = { } + override def onExecutorUnblacklisted( executorUnblacklisted: SparkListenerExecutorUnblacklisted): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 056c0cbded43..ff19cc65552e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,10 @@ private[spark] trait SparkListenerBus listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage => + listener.onExecutorBlacklistedForStage(executorBlacklistedForStage) + case nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage => + listener.onNodeBlacklistedForStage(nodeBlacklistedForStage) case executorBlacklisted: SparkListenerExecutorBlacklisted => listener.onExecutorBlacklisted(executorBlacklisted) case executorUnblacklisted: SparkListenerExecutorUnblacklisted => diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index c513ed36d168..903e25b7986f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.RDDInfo @DeveloperApi class StageInfo( val stageId: Int, - val attemptId: Int, + @deprecated("Use attemptNumber instead", "2.3.0") val attemptId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], @@ -56,6 +56,8 @@ class StageInfo( completionTime = Some(System.currentTimeMillis) } + def attemptNumber(): Int = attemptId + private[spark] def getStatusString: String = { if (completionTime.isDefined) { if (failureReason.isDefined) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala index 3c8cab7504c1..3c7af4f6146f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala @@ -79,7 +79,7 @@ class StatsReportListener extends SparkListener with Logging { x => info.completionTime.getOrElse(System.currentTimeMillis()) - x ).getOrElse("-") - s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " + + s"Stage(${info.stageId}, ${info.attemptNumber}); Name: '${info.name}'; " + s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " + s"Took: $timeTaken msec" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7767ef1803a0..f536fc2a5f0a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -79,6 +79,7 @@ private[spark] abstract class Task[T]( SparkEnv.get.blockManager.registerTask(taskAttemptId) 
context = new TaskContextImpl( stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal partitionId, taskAttemptId, attemptNumber, diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0c11806b3981..8e97b3da3382 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -689,6 +689,20 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Marks the task has completed in all TaskSetManagers for the given stage. + * + * After stage failure and retry, there may be multiple TaskSetManagers for the stage. + * If an earlier attempt of a stage completes a task, we should ensure that the later attempts + * do not also submit those same tasks. That also means that a task completion from an earlier + * attempt can lead to the entire stage getting marked as successful. + */ + private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { + taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => + tsm.markPartitionCompleted(partitionId) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala index 233781f3d971..b680979a466a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -36,8 +36,12 @@ import org.apache.spark.util.Clock * [[TaskSetManager]] this class is designed only to be called from code with a lock on the * TaskScheduler (e.g. its event handlers). It should not be called from other threads. */ -private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, val clock: Clock) - extends Logging { +private[scheduler] class TaskSetBlacklist( + private val listenerBus: LiveListenerBus, + val conf: SparkConf, + val stageId: Int, + val stageAttemptId: Int, + val clock: Clock) extends Logging { private val MAX_TASK_ATTEMPTS_PER_EXECUTOR = conf.get(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR) private val MAX_TASK_ATTEMPTS_PER_NODE = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE) @@ -128,16 +132,23 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, } // Check if enough tasks have failed on the executor to blacklist it for the entire stage. - if (execFailures.numUniqueTasksWithFailures >= MAX_FAILURES_PER_EXEC_STAGE) { + val numFailures = execFailures.numUniqueTasksWithFailures + if (numFailures >= MAX_FAILURES_PER_EXEC_STAGE) { if (blacklistedExecs.add(exec)) { logInfo(s"Blacklisting executor ${exec} for stage $stageId") // This executor has been pushed into the blacklist for this stage. Let's check if it // pushes the whole node into the blacklist. 
val blacklistedExecutorsOnNode = execsWithFailuresOnNode.filter(blacklistedExecs.contains(_)) - if (blacklistedExecutorsOnNode.size >= MAX_FAILED_EXEC_PER_NODE_STAGE) { + val now = clock.getTimeMillis() + listenerBus.post( + SparkListenerExecutorBlacklistedForStage(now, exec, numFailures, stageId, stageAttemptId)) + val numFailExec = blacklistedExecutorsOnNode.size + if (numFailExec >= MAX_FAILED_EXEC_PER_NODE_STAGE) { if (blacklistedNodes.add(host)) { logInfo(s"Blacklisting ${host} for stage $stageId") + listenerBus.post( + SparkListenerNodeBlacklistedForStage(now, host, numFailExec, stageId, stageAttemptId)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c3ed11bfe352..195fc8025e4b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -64,8 +64,7 @@ private[spark] class TaskSetManager( val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) - // Limit of bytes for total size of results (default is 1GB) - val maxResultSize = Utils.getMaxResultSize(conf) + val maxResultSize = conf.get(config.MAX_RESULT_SIZE) val speculationEnabled = conf.getBoolean("spark.speculation", false) @@ -74,6 +73,8 @@ private[spark] class TaskSetManager( val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks + private[scheduler] val partitionToIndex = tasks.zipWithIndex + .map { case (t, idx) => t.partitionId -> idx }.toMap val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) @@ -102,7 +103,7 @@ private[spark] class TaskSetManager( private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { blacklistTracker.map { _ => - new TaskSetBlacklist(conf, stageId, clock) + new TaskSetBlacklist(sched.sc.listenerBus, conf, stageId, taskSet.stageAttemptId, clock) } } @@ -154,7 +155,7 @@ private[spark] class TaskSetManager( private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - private val taskInfos = new HashMap[Long, TaskInfo] + private[scheduler] val taskInfos = new HashMap[Long, TaskInfo] // Use a MedianHeap to record durations of successful tasks so we know when to launch // speculative tasks. This is only used when speculation is enabled, to avoid the overhead @@ -288,7 +289,7 @@ private[spark] class TaskSetManager( None } - /** Check whether a task is currently running an attempt on a given host */ + /** Check whether a task once ran an attempt on a given host */ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { taskAttempts(taskIndex).exists(_.host == host) } @@ -755,6 +756,9 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // There may be multiple tasksets for this stage -- we let all of them know that the partition + // was completed. This may result in some of the tasksets getting completed. + sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. 
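The markPartitionCompletedInAllTaskSets helper added to TaskSchedulerImpl above, together with the markPartitionCompleted method that follows in TaskSetManager, works by translating a partition id back into each attempt's task index so that a success in one attempt can retire the same partition in every other attempt of the stage. A rough, self-contained sketch of that bookkeeping, with hypothetical names, is:

    import scala.collection.mutable

    // Minimal model of one attempt of a task set: each task covers one partition.
    final class TaskSetAttempt(partitionIds: Seq[Int]) {
      private val partitionToIndex: Map[Int, Int] = partitionIds.zipWithIndex.toMap
      private val successful = Array.fill(partitionIds.size)(false)
      var tasksSuccessful = 0
      var isZombie = false

      // Mark a partition as done even if this attempt never ran it itself.
      def markPartitionCompleted(partitionId: Int): Unit =
        partitionToIndex.get(partitionId).foreach { index =>
          if (!successful(index)) {
            successful(index) = true
            tasksSuccessful += 1
            if (tasksSuccessful == partitionIds.size) isZombie = true
          }
        }
    }

    // Scheduler side: fan a completed partition out to every attempt of the stage.
    final class SchedulerModel {
      val attemptsByStage = mutable.Map.empty[Int, mutable.Buffer[TaskSetAttempt]]

      def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int): Unit =
        attemptsByStage.getOrElse(stageId, mutable.Buffer.empty[TaskSetAttempt])
          .foreach(_.markPartitionCompleted(partitionId))
    }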
To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -765,6 +769,19 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } + private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = { + partitionToIndex.get(partitionId).foreach { index => + if (!successful(index)) { + tasksSuccessful += 1 + successful(index) = true + if (tasksSuccessful == numTasks) { + isZombie = true + } + maybeFinishTaskSet() + } + } + } + /** * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the * DAG Scheduler. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4d75063fbf1c..5627a557a12f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -147,7 +147,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case KillExecutorsOnHost(host) => scheduler.getExecutorsAliveOnHost(host).foreach { exec => - killExecutors(exec.toSeq, replace = true, force = true) + killExecutors(exec.toSeq, adjustTargetNumExecutors = false, countFailures = false, + force = true) } case UpdateDelegationTokens(newDelegationTokens) => @@ -584,18 +585,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. * - * When asking the executor to be replaced, the executor loss is considered a failure, and - * killed tasks that are running on the executor will count towards the failure limits. If no - * replacement is being requested, then the tasks will not count towards the limit. - * * @param executorIds identifiers of executors to kill - * @param replace whether to replace the killed executors with new ones, default false + * @param adjustTargetNumExecutors whether the target number of executors be adjusted down + * after these executors have been killed + * @param countFailures if there are tasks running on the executors when they are killed, whether + * those failures be counted to task failure limits? * @param force whether to force kill busy executors, default false * @return the ids of the executors acknowledged by the cluster manager to be removed. */ final override def killExecutors( executorIds: Seq[String], - replace: Boolean, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean): Seq[String] = { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") @@ -610,7 +611,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorsToKill = knownExecutors .filter { id => !executorsPendingToRemove.contains(id) } .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } + executorsToKill.foreach { id => executorsPendingToRemove(id) = !countFailures } logInfo(s"Actual list of executor(s) to be killed is ${executorsToKill.mkString(", ")}") @@ -618,12 +619,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // with the cluster manager to avoid allocating new ones. When computing the new target, // take into account executors that are pending to be added or removed. 
val adjustTotalExecutors = - if (!replace) { + if (adjustTargetNumExecutors) { requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) if (requestedTotalExecutors != (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { logDebug( - s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + s"""killExecutors($executorIds, $adjustTargetNumExecutors, $countFailures, $force): + |Executor counts do not match: |requestedTotalExecutors = $requestedTotalExecutors |numExistingExecutors = $numExistingExecutors |numPendingExecutors = $numPendingExecutors diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 538ae05e4eea..72427dd6ce4d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -206,6 +206,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(clazz) } catch { case NonFatal(_) => // do nothing + case _: NoClassDefFoundError if Utils.isTesting => // See SPARK-23422. } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0562d45ff57c..4103dfb10175 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -90,12 +90,11 @@ private[spark] class BlockStoreShuffleReader[K, C]( dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { - require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } // Sort the output if there is a sort ordering defined. - dep.keyOrdering match { + val resultIter = dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => // Create an ExternalSorter to sort the data. val sorter = @@ -104,9 +103,21 @@ private[spark] class BlockStoreShuffleReader[K, C]( context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener(_ => { + sorter.stop() + }) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter } + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. 
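The resultIter handling in BlockStoreShuffleReader guarantees two things: the external sorter is stopped via a task-completion listener even if the task is cancelled mid-iteration, and the iterator returned to the caller is always interruptible, because aggregation or sorting may already have consumed the original interruptible iterator. A hedged sketch of the "wrap only if not already interruptible" decision, using a hypothetical Cancellable marker in place of Spark's InterruptibleIterator, is:

    object CancellableIterators {
      // Hypothetical marker standing in for Spark's InterruptibleIterator.
      trait Cancellable

      final class CancellableIterator[A](underlying: Iterator[A], cancelled: () => Boolean)
          extends Iterator[A] with Cancellable {
        override def hasNext: Boolean = !cancelled() && underlying.hasNext
        override def next(): A = underlying.next()
      }

      // Re-wrap only when aggregation or sorting replaced the original cancellable iterator
      // with a plain one; otherwise keep the existing wrapper.
      def ensureCancellable[A](it: Iterator[A], cancelled: () => Boolean): Iterator[A] = it match {
        case _: Cancellable => it
        case _ => new CancellableIterator(it, cancelled)
      }
    }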
+ new InterruptibleIterator[Product2[K, C]](context, resultIter) + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 15540485170d..d3f1c7ec1bbe 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -18,8 +18,8 @@ package org.apache.spark.shuffle import java.io._ - -import com.google.common.io.ByteStreams +import java.nio.channels.Channels +import java.nio.file.Files import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging @@ -84,7 +84,7 @@ private[spark] class IndexShuffleBlockResolver( */ private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = { // the index file should have `block + 1` longs as offset. - if (index.length() != (blocks + 1) * 8) { + if (index.length() != (blocks + 1) * 8L) { return null } val lengths = new Array[Long](blocks) @@ -141,19 +141,6 @@ private[spark] class IndexShuffleBlockResolver( val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) try { - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) - Utils.tryWithSafeFinally { - // We take in lengths of each block, need to convert it to offsets. - var offset = 0L - out.writeLong(offset) - for (length <- lengths) { - offset += length - out.writeLong(offset) - } - } { - out.close() - } - val dataFile = getDataFile(shuffleId, mapId) // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure // the following check and rename are atomic. @@ -166,10 +153,22 @@ private[spark] class IndexShuffleBlockResolver( if (dataTmp != null && dataTmp.exists()) { dataTmp.delete() } - indexTmp.delete() } else { // This is the first successful attempt in writing the map outputs for this task, // so override any existing index and data files with the ones we wrote. + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) + Utils.tryWithSafeFinally { + // We take in lengths of each block, need to convert it to offsets. + var offset = 0L + out.writeLong(offset) + for (length <- lengths) { + offset += length + out.writeLong(offset) + } + } { + out.close() + } + if (indexFile.exists()) { indexFile.delete() } @@ -196,11 +195,24 @@ private[spark] class IndexShuffleBlockResolver( // find out the consolidated file, then the offset within that from our index val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) - val in = new DataInputStream(new FileInputStream(indexFile)) + // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code + // which is incorrectly using our file descriptor then this code will fetch the wrong offsets + // (which may cause a reducer to be sent a different reducer's data). The explicit position + // checks added here were a useful debugging aid during SPARK-22982 and may help prevent this + // class of issue from re-occurring in the future which is why they are left here even though + // SPARK-22982 is fixed. 
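The SPARK-22982 comment above explains why the code that follows opens a fresh channel per read, positions it at reduceId * 8, and then verifies the position after reading the two offsets. The same access pattern against any index file with the (blocks + 1)-longs layout can be sketched as follows (hypothetical helper, not the resolver itself):

    import java.io.{DataInputStream, File}
    import java.nio.channels.Channels
    import java.nio.file.Files

    object IndexFileReads {
      // Reads the (offset, nextOffset) pair for one reduce partition from an index file laid
      // out as (numBlocks + 1) big-endian longs, then sanity-checks the channel position.
      def readSegmentBounds(indexFile: File, reduceId: Int): (Long, Long) = {
        val channel = Files.newByteChannel(indexFile.toPath)
        channel.position(reduceId * 8L)
        val in = new DataInputStream(Channels.newInputStream(channel))
        try {
          val offset = in.readLong()
          val nextOffset = in.readLong()
          val expected = reduceId * 8L + 16
          require(channel.position() == expected,
            s"index channel position drifted: expected $expected, was ${channel.position()}")
          (offset, nextOffset)
        } finally {
          in.close()  // also closes the wrapped channel
        }
      }
    }

Because nothing else shares this channel, a mismatch in the position check would indicate exactly the kind of file-descriptor interference described in the comment.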
+ val channel = Files.newByteChannel(indexFile.toPath) + channel.position(blockId.reduceId * 8L) + val in = new DataInputStream(Channels.newInputStream(channel)) try { - ByteStreams.skipFully(in, blockId.reduceId * 8) val offset = in.readLong() val nextOffset = in.readLong() + val actualPosition = channel.position() + val expectedPosition = blockId.reduceId * 8L + 16 + if (actualPosition != expectedPosition) { + throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + + s"expected $expectedPosition but actual position was $actualPosition.") + } new FileSegmentManagedBuffer( transportConf, getDataFile(blockId.shuffleId, blockId.mapId), diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bfb4dc698e32..d9fad64f34c7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -188,9 +188,9 @@ private[spark] object SortShuffleManager extends Logging { log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + s"${dependency.serializer.getClass.getName}, does not support object relocation") false - } else if (dependency.aggregator.isDefined) { - log.debug( - s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") + } else if (dependency.mapSideCombine) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " + + s"map-side aggregation") false } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 636b88e792bf..274399b9cc1f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -50,7 +50,6 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) } else { @@ -107,7 +106,6 @@ private[spark] object SortShuffleWriter { def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { // We cannot bypass sorting if we need to do map-side aggregation. 
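The SortShuffleManager and SortShuffleWriter changes above relax the old invariant: serialized shuffle and the bypass-merge-sort path are now ruled out only when map-side combine is actually requested, not whenever an aggregator happens to be present (an aggregator without map-side combine can still be applied on the read side). A small decision sketch of those predicates, with hypothetical field names and the limits passed in as parameters, is:

    object ShufflePathChoice {
      // Simplified view of the dependency fields that matter for choosing a shuffle write path.
      final case class ShuffleDep(
          serializerSupportsRelocation: Boolean,
          mapSideCombine: Boolean,
          numPartitions: Int)

      // Serialized shuffle is blocked only by map-side combine (plus the serializer and
      // partition-count conditions), not by the mere presence of an aggregator.
      def canUseSerializedShuffle(dep: ShuffleDep, maxPartitions: Int): Boolean =
        dep.serializerSupportsRelocation && !dep.mapSideCombine && dep.numPartitions <= maxPartitions

      // The bypass-merge-sort path is likewise impossible when map-side aggregation is needed.
      def shouldBypassMergeSort(dep: ShuffleDep, bypassMergeThreshold: Int): Boolean =
        !dep.mapSideCombine && dep.numPartitions <= bypassMergeThreshold
    }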
if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") false } else { val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 487a782e865e..5ea161cd0d15 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -211,6 +211,46 @@ private[spark] class AppStatusListener( updateBlackListStatus(event.executorId, true) } + override def onExecutorBlacklistedForStage( + event: SparkListenerExecutorBlacklistedForStage): Unit = { + val now = System.nanoTime() + + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => + setStageBlackListStatus(stage, now, event.executorId) + } + liveExecutors.get(event.executorId).foreach { exec => + addBlackListedStageTo(exec, event.stageId, now) + } + } + + override def onNodeBlacklistedForStage(event: SparkListenerNodeBlacklistedForStage): Unit = { + val now = System.nanoTime() + + // Implicitly blacklist every available executor for the stage associated with this node + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => + val executorIds = liveExecutors.values.filter(_.host == event.hostId).map(_.executorId).toSeq + setStageBlackListStatus(stage, now, executorIds: _*) + } + liveExecutors.values.filter(_.hostname == event.hostId).foreach { exec => + addBlackListedStageTo(exec, event.stageId, now) + } + } + + private def addBlackListedStageTo(exec: LiveExecutor, stageId: Int, now: Long): Unit = { + exec.blacklistedInStages += stageId + liveUpdate(exec, now) + } + + private def setStageBlackListStatus(stage: LiveStage, now: Long, executorIds: String*): Unit = { + executorIds.foreach { executorId => + val executorStageSummary = stage.executorSummary(executorId) + executorStageSummary.isBlacklisted = true + maybeUpdate(executorStageSummary, now) + } + stage.blackListedExecutors ++= executorIds + maybeUpdate(stage, now) + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { updateBlackListStatus(event.executorId, false) } @@ -377,6 +417,10 @@ private[spark] class AppStatusListener( Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => stage.activeTasks += 1 stage.firstLaunchTime = math.min(stage.firstLaunchTime, event.taskInfo.launchTime) + + val locality = event.taskInfo.taskLocality.toString() + val count = stage.localitySummary.getOrElse(locality, 0L) + 1L + stage.localitySummary = stage.localitySummary ++ Map(locality -> count) maybeUpdate(stage, now) stage.jobs.foreach { job => @@ -433,7 +477,7 @@ private[spark] class AppStatusListener( } task.errorMessage = errorMessage val delta = task.updateMetrics(event.taskMetrics) - update(task, now) + update(task, now, last = true) delta }.orNull @@ -450,7 +494,7 @@ private[spark] class AppStatusListener( Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => if (metricsDelta != null) { - stage.metrics.update(metricsDelta) + stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, metricsDelta) } stage.activeTasks -= 1 stage.completedTasks += completedDelta @@ -486,7 +530,7 @@ private[spark] class AppStatusListener( esummary.failedTasks += failedDelta esummary.killedTasks += killedDelta if (metricsDelta != null) { - 
esummary.metrics.update(metricsDelta) + esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, metricsDelta) } maybeUpdate(esummary, now) @@ -529,7 +573,8 @@ private[spark] class AppStatusListener( } override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { - val maybeStage = Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId))) + val maybeStage = + Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber))) maybeStage.foreach { stage => val now = System.nanoTime() stage.info = event.stageInfo @@ -564,12 +609,24 @@ private[spark] class AppStatusListener( stage.executorSummaries.values.foreach(update(_, now)) update(stage, now, last = true) + + val executorIdsForStage = stage.blackListedExecutors + executorIdsForStage.foreach { executorId => + liveExecutors.get(executorId).foreach { exec => + removeBlackListedStageFrom(exec, event.stageInfo.stageId, now) + } + } } appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) kvstore.write(appSummary) } + private def removeBlackListedStageFrom(exec: LiveExecutor, stageId: Int, now: Long) = { + exec.blacklistedInStages -= stageId + liveUpdate(exec, now) + } + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { // This needs to set fields that are already set by onExecutorAdded because the driver is // considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event. @@ -603,11 +660,11 @@ private[spark] class AppStatusListener( maybeUpdate(task, now) Option(liveStages.get((sid, sAttempt))).foreach { stage => - stage.metrics.update(delta) + stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, delta) maybeUpdate(stage, now) val esummary = stage.executorSummary(event.execId) - esummary.metrics.update(delta) + esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, delta) maybeUpdate(esummary, now) } } @@ -689,7 +746,7 @@ private[spark] class AppStatusListener( // can update the executor information too. liveRDDs.get(block.rddId).foreach { rdd => if (updatedStorageLevel.isDefined) { - rdd.storageLevel = updatedStorageLevel.get + rdd.setStorageLevel(updatedStorageLevel.get) } val partition = rdd.partition(block.name) @@ -785,7 +842,7 @@ private[spark] class AppStatusListener( } private def getOrCreateStage(info: StageInfo): LiveStage = { - val stage = liveStages.computeIfAbsent((info.stageId, info.attemptId), + val stage = liveStages.computeIfAbsent((info.stageId, info.attemptNumber), new Function[(Int, Int), LiveStage]() { override def apply(key: (Int, Int)): LiveStage = new LiveStage() }) @@ -813,7 +870,7 @@ private[spark] class AppStatusListener( /** Update a live entity only if it hasn't been updated in the last configured period. 
*/ private def maybeUpdate(entity: LiveEntity, now: Long): Unit = { - if (liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { + if (live && liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { update(entity, now) } } @@ -845,8 +902,8 @@ private[spark] class AppStatusListener( return } - val toDelete = KVUtils.viewToSeq(kvstore.view(classOf[JobDataWrapper]), - countToDelete.toInt) { j => + val view = kvstore.view(classOf[JobDataWrapper]).index("completionTime").first(0L) + val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt) { j => j.info.status != JobExecutionStatus.RUNNING && j.info.status != JobExecutionStatus.UNKNOWN } toDelete.foreach { j => kvstore.delete(j.getClass(), j.info.jobId) } @@ -858,13 +915,16 @@ private[spark] class AppStatusListener( return } - val stages = KVUtils.viewToSeq(kvstore.view(classOf[StageDataWrapper]), - countToDelete.toInt) { s => + // As the completion time of a skipped stage is always -1, we will remove skipped stages first. + // This is safe since the job itself contains enough information to render skipped stages in the + // UI. + val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime") + val stages = KVUtils.viewToSeq(view, countToDelete.toInt) { s => s.info.status != v1.StageStatus.ACTIVE && s.info.status != v1.StageStatus.PENDING } stages.foreach { s => - val key = s.id + val key = Array(s.info.stageId, s.info.attemptId) kvstore.delete(s.getClass(), key) val execSummaries = kvstore.view(classOf[ExecutorStageSummaryWrapper]) @@ -884,15 +944,15 @@ private[spark] class AppStatusListener( .asScala tasks.foreach { t => - kvstore.delete(t.getClass(), t.info.taskId) + kvstore.delete(t.getClass(), t.taskId) } // Check whether there are remaining attempts for the same stage. If there aren't, then // also delete the RDD graph data. val remainingAttempts = kvstore.view(classOf[StageDataWrapper]) .index("stageId") - .first(s.stageId) - .last(s.stageId) + .first(s.info.stageId) + .last(s.info.stageId) .closeableIterator() val hasMoreAttempts = try { @@ -904,23 +964,26 @@ private[spark] class AppStatusListener( } if (!hasMoreAttempts) { - kvstore.delete(classOf[RDDOperationGraphWrapper], s.stageId) + kvstore.delete(classOf[RDDOperationGraphWrapper], s.info.stageId) } + + cleanupCachedQuantiles(key) } } private def cleanupTasks(stage: LiveStage): Unit = { val countToDelete = calculateNumberToRemove(stage.savedTasks.get(), maxTasksPerStage).toInt if (countToDelete > 0) { - val stageKey = Array(stage.info.stageId, stage.info.attemptId) - val view = kvstore.view(classOf[TaskDataWrapper]).index("stage").first(stageKey) - .last(stageKey) + val stageKey = Array(stage.info.stageId, stage.info.attemptNumber) + val view = kvstore.view(classOf[TaskDataWrapper]) + .index(TaskIndexNames.COMPLETION_TIME) + .parent(stageKey) // Try to delete finished tasks only. val toDelete = KVUtils.viewToSeq(view, countToDelete) { t => - !live || t.info.status != TaskState.RUNNING.toString() + !live || t.status != TaskState.RUNNING.toString() } - toDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + toDelete.foreach { t => kvstore.delete(t.getClass(), t.taskId) } stage.savedTasks.addAndGet(-toDelete.size) // If there are more running tasks than the configured limit, delete running tasks. 
This @@ -929,13 +992,34 @@ private[spark] class AppStatusListener( val remaining = countToDelete - toDelete.size if (remaining > 0) { val runningTasksToDelete = view.max(remaining).iterator().asScala.toList - runningTasksToDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + runningTasksToDelete.foreach { t => kvstore.delete(t.getClass(), t.taskId) } stage.savedTasks.addAndGet(-remaining) } + + // On live applications, cleanup any cached quantiles for the stage. This makes sure that + // quantiles will be recalculated after tasks are replaced with newer ones. + // + // This is not needed in the SHS since caching only happens after the event logs are + // completely processed. + if (live) { + cleanupCachedQuantiles(stageKey) + } } stage.cleaning = false } + private def cleanupCachedQuantiles(stageKey: Array[Int]): Unit = { + val cachedQuantiles = kvstore.view(classOf[CachedQuantile]) + .index("stage") + .first(stageKey) + .last(stageKey) + .asScala + .toList + cachedQuantiles.foreach { q => + kvstore.delete(q.getClass(), q.id) + } + } + /** * Remove at least (retainedSize / 10) items to reduce friction. Because tracking may be done * asynchronously, this method may return 0 in case enough items have been deleted already. diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 5a942f528401..688f25a9fdea 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.status.api.v1 import org.apache.spark.ui.scope._ -import org.apache.spark.util.Distribution +import org.apache.spark.util.{Distribution, Utils} import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} /** @@ -95,10 +95,18 @@ private[spark] class AppStatusStore( } def lastStageAttempt(stageId: Int): v1.StageData = { - val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId) + val it = store.view(classOf[StageDataWrapper]) + .index("stageId") + .reverse() + .first(stageId) + .last(stageId) .closeableIterator() try { - it.next().info + if (it.hasNext()) { + it.next().info + } else { + throw new NoSuchElementException(s"No stage with id $stageId") + } } finally { it.close() } @@ -110,107 +118,238 @@ private[spark] class AppStatusStore( if (details) stageWithDetails(stage) else stage } + def taskCount(stageId: Int, stageAttemptId: Int): Long = { + store.count(classOf[TaskDataWrapper], "stage", Array(stageId, stageAttemptId)) + } + + def localitySummary(stageId: Int, stageAttemptId: Int): Map[String, Long] = { + store.read(classOf[StageDataWrapper], Array(stageId, stageAttemptId)).locality + } + + /** + * Calculates a summary of the task metrics for the given stage attempt, returning the + * requested quantiles for the recorded metrics. + * + * This method can be expensive if the requested quantiles are not cached; the method + * will only cache certain quantiles (every 0.05 step), so it's recommended to stick to + * those to avoid expensive scans of all task data. 
+ */ def taskSummary( stageId: Int, stageAttemptId: Int, - quantiles: Array[Double]): v1.TaskMetricDistributions = { - - val stage = Array(stageId, stageAttemptId) - - val rawMetrics = store.view(classOf[TaskDataWrapper]) - .index("stage") - .first(stage) - .last(stage) - .asScala - .flatMap(_.info.taskMetrics) - .toList - .view - - def metricQuantiles(f: v1.TaskMetrics => Double): IndexedSeq[Double] = - Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) - - // We need to do a lot of similar munging to nested metrics here. For each one, - // we want (a) extract the values for nested metrics (b) make a distribution for each metric - // (c) shove the distribution into the right field in our return type and (d) only return - // a result if the option is defined for any of the tasks. MetricHelper is a little util - // to make it a little easier to deal w/ all of the nested options. Mostly it lets us just - // implement one "build" method, which just builds the quantiles for each field. - - val inputMetrics = - new MetricHelper[v1.InputMetrics, v1.InputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.InputMetrics = raw.inputMetrics - - def build: v1.InputMetricDistributions = new v1.InputMetricDistributions( - bytesRead = submetricQuantiles(_.bytesRead), - recordsRead = submetricQuantiles(_.recordsRead) - ) - }.build - - val outputMetrics = - new MetricHelper[v1.OutputMetrics, v1.OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.OutputMetrics = raw.outputMetrics - - def build: v1.OutputMetricDistributions = new v1.OutputMetricDistributions( - bytesWritten = submetricQuantiles(_.bytesWritten), - recordsWritten = submetricQuantiles(_.recordsWritten) - ) - }.build - - val shuffleReadMetrics = - new MetricHelper[v1.ShuffleReadMetrics, v1.ShuffleReadMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleReadMetrics = - raw.shuffleReadMetrics - - def build: v1.ShuffleReadMetricDistributions = new v1.ShuffleReadMetricDistributions( - readBytes = submetricQuantiles { s => s.localBytesRead + s.remoteBytesRead }, - readRecords = submetricQuantiles(_.recordsRead), - remoteBytesRead = submetricQuantiles(_.remoteBytesRead), - remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), - remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), - localBlocksFetched = submetricQuantiles(_.localBlocksFetched), - totalBlocksFetched = submetricQuantiles { s => - s.localBlocksFetched + s.remoteBlocksFetched - }, - fetchWaitTime = submetricQuantiles(_.fetchWaitTime) - ) - }.build - - val shuffleWriteMetrics = - new MetricHelper[v1.ShuffleWriteMetrics, v1.ShuffleWriteMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleWriteMetrics = - raw.shuffleWriteMetrics - - def build: v1.ShuffleWriteMetricDistributions = new v1.ShuffleWriteMetricDistributions( - writeBytes = submetricQuantiles(_.bytesWritten), - writeRecords = submetricQuantiles(_.recordsWritten), - writeTime = submetricQuantiles(_.writeTime) - ) - }.build - - new v1.TaskMetricDistributions( + unsortedQuantiles: Array[Double]): Option[v1.TaskMetricDistributions] = { + val stageKey = Array(stageId, stageAttemptId) + val quantiles = unsortedQuantiles.sorted + + // We don't know how many tasks remain in the store that actually have metrics. So scan one + // metric and count how many valid tasks there are. 
Use skip() instead of next() since it's + // cheaper for disk stores (avoids deserialization). + val count = { + Utils.tryWithResource( + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(TaskIndexNames.EXEC_RUN_TIME) + .first(0L) + .closeableIterator() + ) { it => + var _count = 0L + while (it.hasNext()) { + _count += 1 + it.skip(1) + } + _count + } + } + + if (count <= 0) { + return None + } + + // Find out which quantiles are already cached. The data in the store must match the expected + // task count to be considered, otherwise it will be re-scanned and overwritten. + val cachedQuantiles = quantiles.filter(shouldCacheQuantile).flatMap { q => + val qkey = Array(stageId, stageAttemptId, quantileToString(q)) + asOption(store.read(classOf[CachedQuantile], qkey)).filter(_.taskCount == count) + } + + // If there are no missing quantiles, return the data. Otherwise, just compute everything + // to make the code simpler. + if (cachedQuantiles.size == quantiles.size) { + def toValues(fn: CachedQuantile => Double): IndexedSeq[Double] = cachedQuantiles.map(fn) + + val distributions = new v1.TaskMetricDistributions( + quantiles = quantiles, + executorDeserializeTime = toValues(_.executorDeserializeTime), + executorDeserializeCpuTime = toValues(_.executorDeserializeCpuTime), + executorRunTime = toValues(_.executorRunTime), + executorCpuTime = toValues(_.executorCpuTime), + resultSize = toValues(_.resultSize), + jvmGcTime = toValues(_.jvmGcTime), + resultSerializationTime = toValues(_.resultSerializationTime), + gettingResultTime = toValues(_.gettingResultTime), + schedulerDelay = toValues(_.schedulerDelay), + peakExecutionMemory = toValues(_.peakExecutionMemory), + memoryBytesSpilled = toValues(_.memoryBytesSpilled), + diskBytesSpilled = toValues(_.diskBytesSpilled), + inputMetrics = new v1.InputMetricDistributions( + toValues(_.bytesRead), + toValues(_.recordsRead)), + outputMetrics = new v1.OutputMetricDistributions( + toValues(_.bytesWritten), + toValues(_.recordsWritten)), + shuffleReadMetrics = new v1.ShuffleReadMetricDistributions( + toValues(_.shuffleReadBytes), + toValues(_.shuffleRecordsRead), + toValues(_.shuffleRemoteBlocksFetched), + toValues(_.shuffleLocalBlocksFetched), + toValues(_.shuffleFetchWaitTime), + toValues(_.shuffleRemoteBytesRead), + toValues(_.shuffleRemoteBytesReadToDisk), + toValues(_.shuffleTotalBlocksFetched)), + shuffleWriteMetrics = new v1.ShuffleWriteMetricDistributions( + toValues(_.shuffleWriteBytes), + toValues(_.shuffleWriteRecords), + toValues(_.shuffleWriteTime))) + + return Some(distributions) + } + + // Compute quantiles by scanning the tasks in the store. This is not really stable for live + // stages (e.g. the number of recorded tasks may change while this code is running), but should + // stabilize once the stage finishes. It's also slow, especially with disk stores. 
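The scanTasks helper that follows computes each requested quantile by skipping forward through a view of the tasks kept sorted by the metric of interest: with count valid tasks, quantile q maps to element index min(floor(q * count), count - 1), and sorting the quantiles first means the scan only ever moves forward. A condensed illustration of the same arithmetic over an in-memory sorted sequence (not the kvstore API) is:

    object QuantileScan {
      // Given one task metric sorted ascending, pick each requested quantile the same way the
      // store scan does. Assumes a non-empty column (the real code returns None when count <= 0).
      def quantileValues(sortedMetric: IndexedSeq[Long], quantiles: Seq[Double]): Seq[Double] = {
        val count = sortedMetric.size.toLong
        val indices = quantiles.sorted.map(q => math.min((q * count).toLong, count - 1))
        indices.map(idx => sortedMetric(idx.toInt).toDouble)
      }
      // Example: with 20 tasks, quantiles 0.0, 0.25, 0.5, 0.75 and 1.0 hit indices 0, 5, 10, 15, 19.
    }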
+ val indices = quantiles.map { q => math.min((q * count).toLong, count - 1) } + + def scanTasks(index: String)(fn: TaskDataWrapper => Long): IndexedSeq[Double] = { + Utils.tryWithResource( + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(index) + .first(0L) + .closeableIterator() + ) { it => + var last = Double.NaN + var currentIdx = -1L + indices.map { idx => + if (idx == currentIdx) { + last + } else { + val diff = idx - currentIdx + currentIdx = idx + if (it.skip(diff - 1)) { + last = fn(it.next()).toDouble + last + } else { + Double.NaN + } + } + }.toIndexedSeq + } + } + + val computedQuantiles = new v1.TaskMetricDistributions( quantiles = quantiles, - executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), - executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), - executorRunTime = metricQuantiles(_.executorRunTime), - executorCpuTime = metricQuantiles(_.executorCpuTime), - resultSize = metricQuantiles(_.resultSize), - jvmGcTime = metricQuantiles(_.jvmGcTime), - resultSerializationTime = metricQuantiles(_.resultSerializationTime), - memoryBytesSpilled = metricQuantiles(_.memoryBytesSpilled), - diskBytesSpilled = metricQuantiles(_.diskBytesSpilled), - inputMetrics = inputMetrics, - outputMetrics = outputMetrics, - shuffleReadMetrics = shuffleReadMetrics, - shuffleWriteMetrics = shuffleWriteMetrics - ) + executorDeserializeTime = scanTasks(TaskIndexNames.DESER_TIME) { t => + t.executorDeserializeTime + }, + executorDeserializeCpuTime = scanTasks(TaskIndexNames.DESER_CPU_TIME) { t => + t.executorDeserializeCpuTime + }, + executorRunTime = scanTasks(TaskIndexNames.EXEC_RUN_TIME) { t => t.executorRunTime }, + executorCpuTime = scanTasks(TaskIndexNames.EXEC_CPU_TIME) { t => t.executorCpuTime }, + resultSize = scanTasks(TaskIndexNames.RESULT_SIZE) { t => t.resultSize }, + jvmGcTime = scanTasks(TaskIndexNames.GC_TIME) { t => t.jvmGcTime }, + resultSerializationTime = scanTasks(TaskIndexNames.SER_TIME) { t => + t.resultSerializationTime + }, + gettingResultTime = scanTasks(TaskIndexNames.GETTING_RESULT_TIME) { t => + t.gettingResultTime + }, + schedulerDelay = scanTasks(TaskIndexNames.SCHEDULER_DELAY) { t => t.schedulerDelay }, + peakExecutionMemory = scanTasks(TaskIndexNames.PEAK_MEM) { t => t.peakExecutionMemory }, + memoryBytesSpilled = scanTasks(TaskIndexNames.MEM_SPILL) { t => t.memoryBytesSpilled }, + diskBytesSpilled = scanTasks(TaskIndexNames.DISK_SPILL) { t => t.diskBytesSpilled }, + inputMetrics = new v1.InputMetricDistributions( + scanTasks(TaskIndexNames.INPUT_SIZE) { t => t.inputBytesRead }, + scanTasks(TaskIndexNames.INPUT_RECORDS) { t => t.inputRecordsRead }), + outputMetrics = new v1.OutputMetricDistributions( + scanTasks(TaskIndexNames.OUTPUT_SIZE) { t => t.outputBytesWritten }, + scanTasks(TaskIndexNames.OUTPUT_RECORDS) { t => t.outputRecordsWritten }), + shuffleReadMetrics = new v1.ShuffleReadMetricDistributions( + scanTasks(TaskIndexNames.SHUFFLE_TOTAL_READS) { m => + m.shuffleLocalBytesRead + m.shuffleRemoteBytesRead + }, + scanTasks(TaskIndexNames.SHUFFLE_READ_RECORDS) { t => t.shuffleRecordsRead }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_BLOCKS) { t => t.shuffleRemoteBlocksFetched }, + scanTasks(TaskIndexNames.SHUFFLE_LOCAL_BLOCKS) { t => t.shuffleLocalBlocksFetched }, + scanTasks(TaskIndexNames.SHUFFLE_READ_TIME) { t => t.shuffleFetchWaitTime }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_READS) { t => t.shuffleRemoteBytesRead }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_READS_TO_DISK) { t => + 
t.shuffleRemoteBytesReadToDisk + }, + scanTasks(TaskIndexNames.SHUFFLE_TOTAL_BLOCKS) { m => + m.shuffleLocalBlocksFetched + m.shuffleRemoteBlocksFetched + }), + shuffleWriteMetrics = new v1.ShuffleWriteMetricDistributions( + scanTasks(TaskIndexNames.SHUFFLE_WRITE_SIZE) { t => t.shuffleBytesWritten }, + scanTasks(TaskIndexNames.SHUFFLE_WRITE_RECORDS) { t => t.shuffleRecordsWritten }, + scanTasks(TaskIndexNames.SHUFFLE_WRITE_TIME) { t => t.shuffleWriteTime })) + + // Go through the computed quantiles and cache the values that match the caching criteria. + computedQuantiles.quantiles.zipWithIndex + .filter { case (q, _) => quantiles.contains(q) && shouldCacheQuantile(q) } + .foreach { case (q, idx) => + val cached = new CachedQuantile(stageId, stageAttemptId, quantileToString(q), count, + executorDeserializeTime = computedQuantiles.executorDeserializeTime(idx), + executorDeserializeCpuTime = computedQuantiles.executorDeserializeCpuTime(idx), + executorRunTime = computedQuantiles.executorRunTime(idx), + executorCpuTime = computedQuantiles.executorCpuTime(idx), + resultSize = computedQuantiles.resultSize(idx), + jvmGcTime = computedQuantiles.jvmGcTime(idx), + resultSerializationTime = computedQuantiles.resultSerializationTime(idx), + gettingResultTime = computedQuantiles.gettingResultTime(idx), + schedulerDelay = computedQuantiles.schedulerDelay(idx), + peakExecutionMemory = computedQuantiles.peakExecutionMemory(idx), + memoryBytesSpilled = computedQuantiles.memoryBytesSpilled(idx), + diskBytesSpilled = computedQuantiles.diskBytesSpilled(idx), + + bytesRead = computedQuantiles.inputMetrics.bytesRead(idx), + recordsRead = computedQuantiles.inputMetrics.recordsRead(idx), + + bytesWritten = computedQuantiles.outputMetrics.bytesWritten(idx), + recordsWritten = computedQuantiles.outputMetrics.recordsWritten(idx), + + shuffleReadBytes = computedQuantiles.shuffleReadMetrics.readBytes(idx), + shuffleRecordsRead = computedQuantiles.shuffleReadMetrics.readRecords(idx), + shuffleRemoteBlocksFetched = + computedQuantiles.shuffleReadMetrics.remoteBlocksFetched(idx), + shuffleLocalBlocksFetched = computedQuantiles.shuffleReadMetrics.localBlocksFetched(idx), + shuffleFetchWaitTime = computedQuantiles.shuffleReadMetrics.fetchWaitTime(idx), + shuffleRemoteBytesRead = computedQuantiles.shuffleReadMetrics.remoteBytesRead(idx), + shuffleRemoteBytesReadToDisk = + computedQuantiles.shuffleReadMetrics.remoteBytesReadToDisk(idx), + shuffleTotalBlocksFetched = computedQuantiles.shuffleReadMetrics.totalBlocksFetched(idx), + + shuffleWriteBytes = computedQuantiles.shuffleWriteMetrics.writeBytes(idx), + shuffleWriteRecords = computedQuantiles.shuffleWriteMetrics.writeRecords(idx), + shuffleWriteTime = computedQuantiles.shuffleWriteMetrics.writeTime(idx)) + store.write(cached) + } + + Some(computedQuantiles) } + /** + * Whether to cache information about a specific metric quantile. We cache quantiles at every 0.05 + * step, which covers the default values used both in the API and in the stages page. 
+ */ + private def shouldCacheQuantile(q: Double): Boolean = (math.round(q * 100) % 5) == 0 + + private def quantileToString(q: Double): String = math.round(q * 100).toString + def taskList(stageId: Int, stageAttemptId: Int, maxTasks: Int): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) store.view(classOf[TaskDataWrapper]).index("stage").first(stageKey).last(stageKey).reverse() - .max(maxTasks).asScala.map(_.info).toSeq.reverse + .max(maxTasks).asScala.map(_.toApi).toSeq.reverse } def taskList( @@ -219,18 +358,43 @@ private[spark] class AppStatusStore( offset: Int, length: Int, sortBy: v1.TaskSorting): Seq[v1.TaskData] = { + val (indexName, ascending) = sortBy match { + case v1.TaskSorting.ID => + (None, true) + case v1.TaskSorting.INCREASING_RUNTIME => + (Some(TaskIndexNames.EXEC_RUN_TIME), true) + case v1.TaskSorting.DECREASING_RUNTIME => + (Some(TaskIndexNames.EXEC_RUN_TIME), false) + } + taskList(stageId, stageAttemptId, offset, length, indexName, ascending) + } + + def taskList( + stageId: Int, + stageAttemptId: Int, + offset: Int, + length: Int, + sortBy: Option[String], + ascending: Boolean): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) val base = store.view(classOf[TaskDataWrapper]) val indexed = sortBy match { - case v1.TaskSorting.ID => + case Some(index) => + base.index(index).parent(stageKey) + + case _ => + // Sort by ID, which is the "stage" index. base.index("stage").first(stageKey).last(stageKey) - case v1.TaskSorting.INCREASING_RUNTIME => - base.index("runtime").first(stageKey ++ Array(-1L)).last(stageKey ++ Array(Long.MaxValue)) - case v1.TaskSorting.DECREASING_RUNTIME => - base.index("runtime").first(stageKey ++ Array(Long.MaxValue)).last(stageKey ++ Array(-1L)) - .reverse() } - indexed.skip(offset).max(length).asScala.map(_.info).toSeq + + val ordered = if (ascending) indexed else indexed.reverse() + ordered.skip(offset).max(length).asScala.map(_.toApi).toSeq + } + + def executorSummary(stageId: Int, attemptId: Int): Map[String, v1.ExecutorStageSummary] = { + val stageKey = Array(stageId, attemptId) + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey).last(stageKey) + .asScala.map { exec => (exec.executorId -> exec.info) }.toMap } def rddList(cachedOnly: Boolean = true): Seq[v1.RDDStorageInfo] = { @@ -256,12 +420,6 @@ private[spark] class AppStatusStore( .map { t => (t.taskId, t) } .toMap - val stageKey = Array(stage.stageId, stage.attemptId) - val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey) - .last(stageKey).closeableIterator().asScala - .map { exec => (exec.executorId -> exec.info) } - .toMap - new v1.StageData( stage.status, stage.stageId, @@ -295,7 +453,7 @@ private[spark] class AppStatusStore( stage.rddIds, stage.accumulatorUpdates, Some(tasks), - Some(execs), + Some(executorSummary(stage.stageId, stage.attemptId)), stage.killedTasksSummary) } @@ -352,22 +510,3 @@ private[spark] object AppStatusStore { } } - -/** - * Helper for getting distributions from nested metric types. 
- */ -private abstract class MetricHelper[I, O]( - rawMetrics: Seq[v1.TaskMetrics], - quantiles: Array[Double]) { - - def getSubmetrics(raw: v1.TaskMetrics): I - - def build: O - - val data: Seq[I] = rawMetrics.map(getSubmetrics) - - /** applies the given function to all input metrics, and returns the quantiles */ - def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { - Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) - } -} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala new file mode 100644 index 000000000000..87f434daf487 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.apache.spark.status.api.v1.TaskData + +private[spark] object AppStatusUtils { + + private val TASK_FINISHED_STATES = Set("FAILED", "KILLED", "SUCCESS") + + private def isTaskFinished(task: TaskData): Boolean = { + TASK_FINISHED_STATES.contains(task.status) + } + + def schedulerDelay(task: TaskData): Long = { + if (isTaskFinished(task) && task.taskMetrics.isDefined && task.duration.isDefined) { + val m = task.taskMetrics.get + schedulerDelay(task.launchTime.getTime(), fetchStart(task), task.duration.get, + m.executorDeserializeTime, m.resultSerializationTime, m.executorRunTime) + } else { + // The task is still running and the metrics like executorRunTime are not available. 
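      // Reporting 0 for a still-running task (rather than a partial estimate) means callers only
      // ever see a delay computed from metrics that executors have actually reported.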
+ 0L + } + } + + def gettingResultTime(task: TaskData): Long = { + gettingResultTime(task.launchTime.getTime(), fetchStart(task), task.duration.getOrElse(-1L)) + } + + def schedulerDelay( + launchTime: Long, + fetchStart: Long, + duration: Long, + deserializeTime: Long, + serializeTime: Long, + runTime: Long): Long = { + math.max(0, duration - runTime - deserializeTime - serializeTime - + gettingResultTime(launchTime, fetchStart, duration)) + } + + def gettingResultTime(launchTime: Long, fetchStart: Long, duration: Long): Long = { + if (fetchStart > 0) { + if (duration > 0) { + launchTime + duration - fetchStart + } else { + System.currentTimeMillis() - fetchStart + } + } else { + 0L + } + } + + private def fetchStart(task: TaskData): Long = { + if (task.resultFetchStart.isDefined) { + task.resultFetchStart.get.getTime() + } else { + -1 + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 52e83f250d34..79e3f13b826c 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -20,8 +20,11 @@ package org.apache.spark.status import java.util.Date import java.util.concurrent.atomic.AtomicInteger +import scala.collection.immutable.{HashSet, TreeSet} import scala.collection.mutable.HashMap +import com.google.common.collect.Interners + import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} @@ -119,7 +122,9 @@ private class LiveTask( import LiveEntityHelpers._ - private var recordedMetrics: v1.TaskMetrics = null + // The task metrics use a special value when no metrics have been reported. The special value is + // checked when calculating indexed values when writing to the store (see [[TaskDataWrapper]]). 
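  // A default of -1L marks every field as "not reported"; TaskDataWrapper.hasMetrics treats a
  // negative executorDeserializeTime as the absence of metrics for the task.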
+ private var metrics: v1.TaskMetrics = createMetrics(default = -1L) var errorMessage: Option[String] = None @@ -129,8 +134,8 @@ private class LiveTask( */ def updateMetrics(metrics: TaskMetrics): v1.TaskMetrics = { if (metrics != null) { - val old = recordedMetrics - recordedMetrics = new v1.TaskMetrics( + val old = this.metrics + val newMetrics = createMetrics( metrics.executorDeserializeTime, metrics.executorDeserializeCpuTime, metrics.executorRunTime, @@ -141,73 +146,35 @@ private class LiveTask( metrics.memoryBytesSpilled, metrics.diskBytesSpilled, metrics.peakExecutionMemory, - new v1.InputMetrics( - metrics.inputMetrics.bytesRead, - metrics.inputMetrics.recordsRead), - new v1.OutputMetrics( - metrics.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten), - new v1.ShuffleReadMetrics( - metrics.shuffleReadMetrics.remoteBlocksFetched, - metrics.shuffleReadMetrics.localBlocksFetched, - metrics.shuffleReadMetrics.fetchWaitTime, - metrics.shuffleReadMetrics.remoteBytesRead, - metrics.shuffleReadMetrics.remoteBytesReadToDisk, - metrics.shuffleReadMetrics.localBytesRead, - metrics.shuffleReadMetrics.recordsRead), - new v1.ShuffleWriteMetrics( - metrics.shuffleWriteMetrics.bytesWritten, - metrics.shuffleWriteMetrics.writeTime, - metrics.shuffleWriteMetrics.recordsWritten)) - if (old != null) calculateMetricsDelta(recordedMetrics, old) else recordedMetrics + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten) + + this.metrics = newMetrics + + // Only calculate the delta if the old metrics contain valid information, otherwise + // the new metrics are the delta. + if (old.executorDeserializeTime >= 0L) { + subtractMetrics(newMetrics, old) + } else { + newMetrics + } } else { null } } - /** - * Return a new TaskMetrics object containing the delta of the various fields of the given - * metrics objects. This is currently targeted at updating stage data, so it does not - * necessarily calculate deltas for all the fields. 
- */ - private def calculateMetricsDelta( - metrics: v1.TaskMetrics, - old: v1.TaskMetrics): v1.TaskMetrics = { - val shuffleWriteDelta = new v1.ShuffleWriteMetrics( - metrics.shuffleWriteMetrics.bytesWritten - old.shuffleWriteMetrics.bytesWritten, - 0L, - metrics.shuffleWriteMetrics.recordsWritten - old.shuffleWriteMetrics.recordsWritten) - - val shuffleReadDelta = new v1.ShuffleReadMetrics( - 0L, 0L, 0L, - metrics.shuffleReadMetrics.remoteBytesRead - old.shuffleReadMetrics.remoteBytesRead, - metrics.shuffleReadMetrics.remoteBytesReadToDisk - - old.shuffleReadMetrics.remoteBytesReadToDisk, - metrics.shuffleReadMetrics.localBytesRead - old.shuffleReadMetrics.localBytesRead, - metrics.shuffleReadMetrics.recordsRead - old.shuffleReadMetrics.recordsRead) - - val inputDelta = new v1.InputMetrics( - metrics.inputMetrics.bytesRead - old.inputMetrics.bytesRead, - metrics.inputMetrics.recordsRead - old.inputMetrics.recordsRead) - - val outputDelta = new v1.OutputMetrics( - metrics.outputMetrics.bytesWritten - old.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten - old.outputMetrics.recordsWritten) - - new v1.TaskMetrics( - 0L, 0L, - metrics.executorRunTime - old.executorRunTime, - metrics.executorCpuTime - old.executorCpuTime, - 0L, 0L, 0L, - metrics.memoryBytesSpilled - old.memoryBytesSpilled, - metrics.diskBytesSpilled - old.diskBytesSpilled, - 0L, - inputDelta, - outputDelta, - shuffleReadDelta, - shuffleWriteDelta) - } - override protected def doUpdate(): Any = { val duration = if (info.finished) { info.duration @@ -215,22 +182,48 @@ private class LiveTask( info.timeRunning(lastUpdateTime.getOrElse(System.currentTimeMillis())) } - val task = new v1.TaskData( + new TaskDataWrapper( info.taskId, info.index, info.attemptNumber, - new Date(info.launchTime), - if (info.gettingResult) Some(new Date(info.gettingResultTime)) else None, - Some(duration), - info.executorId, - info.host, - info.status, - info.taskLocality.toString(), + info.launchTime, + if (info.gettingResult) info.gettingResultTime else -1L, + duration, + weakIntern(info.executorId), + weakIntern(info.host), + weakIntern(info.status), + weakIntern(info.taskLocality.toString()), info.speculative, newAccumulatorInfos(info.accumulables), errorMessage, - Option(recordedMetrics)) - new TaskDataWrapper(task, stageId, stageAttemptId) + + metrics.executorDeserializeTime, + metrics.executorDeserializeCpuTime, + metrics.executorRunTime, + metrics.executorCpuTime, + metrics.resultSize, + metrics.jvmGcTime, + metrics.resultSerializationTime, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + metrics.peakExecutionMemory, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten, + + stageId, + stageAttemptId) } } @@ -262,6 +255,7 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE var totalShuffleRead = 0L var totalShuffleWrite = 0L var isBlacklisted = false + var blacklistedInStages: Set[Int] = TreeSet() var executorLogs = Map[String, String]() @@ 
-307,56 +301,27 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE Option(removeTime), Option(removeReason), executorLogs, - memoryMetrics) + memoryMetrics, + blacklistedInStages) new ExecutorSummaryWrapper(info) } } -/** Metrics tracked per stage (both total and per executor). */ -private class MetricsTracker { - var executorRunTime = 0L - var executorCpuTime = 0L - var inputBytes = 0L - var inputRecords = 0L - var outputBytes = 0L - var outputRecords = 0L - var shuffleReadBytes = 0L - var shuffleReadRecords = 0L - var shuffleWriteBytes = 0L - var shuffleWriteRecords = 0L - var memoryBytesSpilled = 0L - var diskBytesSpilled = 0L - - def update(delta: v1.TaskMetrics): Unit = { - executorRunTime += delta.executorRunTime - executorCpuTime += delta.executorCpuTime - inputBytes += delta.inputMetrics.bytesRead - inputRecords += delta.inputMetrics.recordsRead - outputBytes += delta.outputMetrics.bytesWritten - outputRecords += delta.outputMetrics.recordsWritten - shuffleReadBytes += delta.shuffleReadMetrics.localBytesRead + - delta.shuffleReadMetrics.remoteBytesRead - shuffleReadRecords += delta.shuffleReadMetrics.recordsRead - shuffleWriteBytes += delta.shuffleWriteMetrics.bytesWritten - shuffleWriteRecords += delta.shuffleWriteMetrics.recordsWritten - memoryBytesSpilled += delta.memoryBytesSpilled - diskBytesSpilled += delta.diskBytesSpilled - } - -} - private class LiveExecutorStageSummary( stageId: Int, attemptId: Int, executorId: String) extends LiveEntity { + import LiveEntityHelpers._ + var taskTime = 0L var succeededTasks = 0 var failedTasks = 0 var killedTasks = 0 + var isBlacklisted = false - val metrics = new MetricsTracker() + var metrics = createMetrics(default = 0L) override protected def doUpdate(): Any = { val info = new v1.ExecutorStageSummary( @@ -364,16 +329,17 @@ private class LiveExecutorStageSummary( failedTasks, succeededTasks, killedTasks, - metrics.inputBytes, - metrics.inputRecords, - metrics.outputBytes, - metrics.outputRecords, - metrics.shuffleReadBytes, - metrics.shuffleReadRecords, - metrics.shuffleWriteBytes, - metrics.shuffleWriteRecords, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBytesRead + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, - metrics.diskBytesSpilled) + metrics.diskBytesSpilled, + isBlacklisted) new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) } @@ -402,24 +368,28 @@ private class LiveStage extends LiveEntity { var firstLaunchTime = Long.MaxValue - val metrics = new MetricsTracker() + var localitySummary: Map[String, Long] = Map() + + var metrics = createMetrics(default = 0L) val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + var blackListedExecutors = new HashSet[String]() + // Used for cleanup of tasks after they reach the configured limit. Not written to the store. 
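  // Presumably `cleaning` flags that a prune of old tasks is already in flight, while `savedTasks`
  // counts how many task rows this stage currently has in the store; neither value is persisted.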
@volatile var cleaning = false var savedTasks = new AtomicInteger(0) def executorSummary(executorId: String): LiveExecutorStageSummary = { executorSummaries.getOrElseUpdate(executorId, - new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) + new LiveExecutorStageSummary(info.stageId, info.attemptNumber, executorId)) } def toApi(): v1.StageData = { new v1.StageData( status, info.stageId, - info.attemptId, + info.attemptNumber, info.numTasks, activeTasks, @@ -435,14 +405,14 @@ private class LiveStage extends LiveEntity { info.completionTime.map(new Date(_)), info.failureReason, - metrics.inputBytes, - metrics.inputRecords, - metrics.outputBytes, - metrics.outputRecords, - metrics.shuffleReadBytes, - metrics.shuffleReadRecords, - metrics.shuffleWriteBytes, - metrics.shuffleWriteRecords, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.localBytesRead + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, metrics.diskBytesSpilled, @@ -459,13 +429,15 @@ private class LiveStage extends LiveEntity { } override protected def doUpdate(): Any = { - new StageDataWrapper(toApi(), jobIds) + new StageDataWrapper(toApi(), jobIds, localitySummary) } } private class LiveRDDPartition(val blockName: String) { + import LiveEntityHelpers._ + // Pointers used by RDDPartitionSeq. @volatile var prev: LiveRDDPartition = null @volatile var next: LiveRDDPartition = null @@ -485,7 +457,7 @@ private class LiveRDDPartition(val blockName: String) { diskUsed: Long): Unit = { value = new v1.RDDPartitionInfo( blockName, - storageLevel, + weakIntern(storageLevel), memoryUsed, diskUsed, executors) @@ -495,6 +467,8 @@ private class LiveRDDPartition(val blockName: String) { private class LiveRDDDistribution(exec: LiveExecutor) { + import LiveEntityHelpers._ + val executorId = exec.executorId var memoryUsed = 0L var diskUsed = 0L @@ -508,7 +482,7 @@ private class LiveRDDDistribution(exec: LiveExecutor) { def toApi(): v1.RDDDataDistribution = { if (lastUpdate == null) { lastUpdate = new v1.RDDDataDistribution( - exec.hostPort, + weakIntern(exec.hostPort), memoryUsed, exec.maxMemory - exec.memoryUsed, diskUsed, @@ -524,7 +498,9 @@ private class LiveRDDDistribution(exec: LiveExecutor) { private class LiveRDD(val info: RDDInfo) extends LiveEntity { - var storageLevel: String = info.storageLevel.description + import LiveEntityHelpers._ + + var storageLevel: String = weakIntern(info.storageLevel.description) var memoryUsed = 0L var diskUsed = 0L @@ -533,6 +509,10 @@ private class LiveRDD(val info: RDDInfo) extends LiveEntity { private val distributions = new HashMap[String, LiveRDDDistribution]() + def setStorageLevel(level: String): Unit = { + this.storageLevel = weakIntern(level) + } + def partition(blockName: String): LiveRDDPartition = { partitions.getOrElseUpdate(blockName, { val part = new LiveRDDPartition(blockName) @@ -593,6 +573,9 @@ private class SchedulerPool(name: String) extends LiveEntity { private object LiveEntityHelpers { + private val stringInterner = Interners.newWeakInterner[String]() + + def newAccumulatorInfos(accums: Iterable[AccumulableInfo]): Seq[v1.AccumulableInfo] = { accums .filter { acc => @@ -604,13 +587,119 @@ private object LiveEntityHelpers { .map { acc => new v1.AccumulableInfo( acc.id, - acc.name.orNull, + 
acc.name.map(weakIntern).orNull, acc.update.map(_.toString()), acc.value.map(_.toString()).orNull) } .toSeq } + /** String interning to reduce the memory usage. */ + def weakIntern(s: String): String = { + stringInterner.intern(s) + } + + // scalastyle:off argcount + def createMetrics( + executorDeserializeTime: Long, + executorDeserializeCpuTime: Long, + executorRunTime: Long, + executorCpuTime: Long, + resultSize: Long, + jvmGcTime: Long, + resultSerializationTime: Long, + memoryBytesSpilled: Long, + diskBytesSpilled: Long, + peakExecutionMemory: Long, + inputBytesRead: Long, + inputRecordsRead: Long, + outputBytesWritten: Long, + outputRecordsWritten: Long, + shuffleRemoteBlocksFetched: Long, + shuffleLocalBlocksFetched: Long, + shuffleFetchWaitTime: Long, + shuffleRemoteBytesRead: Long, + shuffleRemoteBytesReadToDisk: Long, + shuffleLocalBytesRead: Long, + shuffleRecordsRead: Long, + shuffleBytesWritten: Long, + shuffleWriteTime: Long, + shuffleRecordsWritten: Long): v1.TaskMetrics = { + new v1.TaskMetrics( + executorDeserializeTime, + executorDeserializeCpuTime, + executorRunTime, + executorCpuTime, + resultSize, + jvmGcTime, + resultSerializationTime, + memoryBytesSpilled, + diskBytesSpilled, + peakExecutionMemory, + new v1.InputMetrics( + inputBytesRead, + inputRecordsRead), + new v1.OutputMetrics( + outputBytesWritten, + outputRecordsWritten), + new v1.ShuffleReadMetrics( + shuffleRemoteBlocksFetched, + shuffleLocalBlocksFetched, + shuffleFetchWaitTime, + shuffleRemoteBytesRead, + shuffleRemoteBytesReadToDisk, + shuffleLocalBytesRead, + shuffleRecordsRead), + new v1.ShuffleWriteMetrics( + shuffleBytesWritten, + shuffleWriteTime, + shuffleRecordsWritten)) + } + // scalastyle:on argcount + + def createMetrics(default: Long): v1.TaskMetrics = { + createMetrics(default, default, default, default, default, default, default, default, + default, default, default, default, default, default, default, default, + default, default, default, default, default, default, default, default) + } + + /** Add m2 values to m1. */ + def addMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics): v1.TaskMetrics = addMetrics(m1, m2, 1) + + /** Subtract m2 values from m1. 
*/ + def subtractMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics): v1.TaskMetrics = { + addMetrics(m1, m2, -1) + } + + private def addMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics, mult: Int): v1.TaskMetrics = { + createMetrics( + m1.executorDeserializeTime + m2.executorDeserializeTime * mult, + m1.executorDeserializeCpuTime + m2.executorDeserializeCpuTime * mult, + m1.executorRunTime + m2.executorRunTime * mult, + m1.executorCpuTime + m2.executorCpuTime * mult, + m1.resultSize + m2.resultSize * mult, + m1.jvmGcTime + m2.jvmGcTime * mult, + m1.resultSerializationTime + m2.resultSerializationTime * mult, + m1.memoryBytesSpilled + m2.memoryBytesSpilled * mult, + m1.diskBytesSpilled + m2.diskBytesSpilled * mult, + m1.peakExecutionMemory + m2.peakExecutionMemory * mult, + m1.inputMetrics.bytesRead + m2.inputMetrics.bytesRead * mult, + m1.inputMetrics.recordsRead + m2.inputMetrics.recordsRead * mult, + m1.outputMetrics.bytesWritten + m2.outputMetrics.bytesWritten * mult, + m1.outputMetrics.recordsWritten + m2.outputMetrics.recordsWritten * mult, + m1.shuffleReadMetrics.remoteBlocksFetched + m2.shuffleReadMetrics.remoteBlocksFetched * mult, + m1.shuffleReadMetrics.localBlocksFetched + m2.shuffleReadMetrics.localBlocksFetched * mult, + m1.shuffleReadMetrics.fetchWaitTime + m2.shuffleReadMetrics.fetchWaitTime * mult, + m1.shuffleReadMetrics.remoteBytesRead + m2.shuffleReadMetrics.remoteBytesRead * mult, + m1.shuffleReadMetrics.remoteBytesReadToDisk + + m2.shuffleReadMetrics.remoteBytesReadToDisk * mult, + m1.shuffleReadMetrics.localBytesRead + m2.shuffleReadMetrics.localBytesRead * mult, + m1.shuffleReadMetrics.recordsRead + m2.shuffleReadMetrics.recordsRead * mult, + m1.shuffleWriteMetrics.bytesWritten + m2.shuffleWriteMetrics.bytesWritten * mult, + m1.shuffleWriteMetrics.writeTime + m2.shuffleWriteMetrics.writeTime * mult, + m1.shuffleWriteMetrics.recordsWritten + m2.shuffleWriteMetrics.recordsWritten * mult) + } + } /** diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index ed9bdc6e1e3c..d121068718b8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -49,6 +49,7 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}") def application(): Class[OneApplicationResource] = classOf[OneApplicationResource] + @GET @Path("version") def version(): VersionInfo = new VersionInfo(org.apache.spark.SPARK_VERSION) @@ -157,6 +158,14 @@ private[v1] class NotFoundException(msg: String) extends WebApplicationException .build() ) +private[v1] class ServiceUnavailable(msg: String) extends WebApplicationException( + new ServiceUnavailableException(msg), + Response + .status(Response.Status.SERVICE_UNAVAILABLE) + .entity(ErrorWrapper(msg)) + .build() +) + private[v1] class BadParameterException(msg: String) extends WebApplicationException( new IllegalArgumentException(msg), Response diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index bd4df07e7afc..974697890dd0 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -19,13 +19,13 @@ package org.apache.spark.status.api.v1 import java.io.OutputStream import 
java.util.{List => JList} import java.util.zip.ZipOutputStream -import javax.ws.rs.{GET, Path, PathParam, Produces, QueryParam} +import javax.ws.rs._ import javax.ws.rs.core.{MediaType, Response, StreamingOutput} import scala.util.control.NonFatal -import org.apache.spark.JobExecutionStatus -import org.apache.spark.ui.SparkUI +import org.apache.spark.{JobExecutionStatus, SparkContext} +import org.apache.spark.ui.UIUtils @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class AbstractApplicationResource extends BaseAppResource { @@ -51,6 +51,29 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { @Path("executors") def executorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(true)) + @GET + @Path("executors/{executorId}/threads") + def threadDump(@PathParam("executorId") execId: String): Array[ThreadStackTrace] = withUI { ui => + if (execId != SparkContext.DRIVER_IDENTIFIER && !execId.forall(Character.isDigit)) { + throw new BadParameterException( + s"Invalid executorId: neither '${SparkContext.DRIVER_IDENTIFIER}' nor number.") + } + + val safeSparkContext = ui.sc.getOrElse { + throw new ServiceUnavailable("Thread dumps not available through the history server.") + } + + ui.store.asOption(ui.store.executorSummary(execId)) match { + case Some(executorSummary) if executorSummary.isActive => + val safeThreadDump = safeSparkContext.getExecutorThreadDump(execId).getOrElse { + throw new NotFoundException("No thread dump is available.") + } + safeThreadDump + case Some(_) => throw new BadParameterException("Executor is not active.") + case _ => throw new NotFoundException("Executor does not exist.") + } + } + @GET @Path("allexecutors") def allExecutorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(false)) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 3b879545b3d2..96249e4bfd5f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -87,7 +87,8 @@ private[v1] class StagesResource extends BaseAppResource { } } - ui.store.taskSummary(stageId, stageAttemptId, quantiles) + ui.store.taskSummary(stageId, stageAttemptId, quantiles).getOrElse( + throw new NotFoundException(s"No tasks reported metrics for $stageId / $stageAttemptId yet.")) } @GET diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 45eaf935fb08..971d7e90fa7b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -19,6 +19,8 @@ package org.apache.spark.status.api.v1 import java.lang.{Long => JLong} import java.util.Date +import scala.xml.{NodeSeq, Text} + import com.fasterxml.jackson.annotation.JsonIgnoreProperties import com.fasterxml.jackson.databind.annotation.JsonDeserialize @@ -68,7 +70,8 @@ class ExecutorStageSummary private[spark]( val shuffleWrite : Long, val shuffleWriteRecords : Long, val memoryBytesSpilled : Long, - val diskBytesSpilled : Long) + val diskBytesSpilled : Long, + val isBlacklistedForStage: Boolean) class ExecutorSummary private[spark]( val id: String, @@ -94,7 +97,8 @@ class ExecutorSummary private[spark]( val removeTime: Option[Date], val removeReason: Option[String], val executorLogs: Map[String, String], - val memoryMetrics: Option[MemoryMetrics]) + val 
memoryMetrics: Option[MemoryMetrics], + val blacklistedInStages: Set[Int]) class MemoryMetrics private[spark]( val usedOnHeapStorageMemory: Long, @@ -261,6 +265,9 @@ class TaskMetricDistributions private[spark]( val resultSize: IndexedSeq[Double], val jvmGcTime: IndexedSeq[Double], val resultSerializationTime: IndexedSeq[Double], + val gettingResultTime: IndexedSeq[Double], + val schedulerDelay: IndexedSeq[Double], + val peakExecutionMemory: IndexedSeq[Double], val memoryBytesSpilled: IndexedSeq[Double], val diskBytesSpilled: IndexedSeq[Double], @@ -311,3 +318,32 @@ class RuntimeInfo private[spark]( val javaVersion: String, val javaHome: String, val scalaVersion: String) + +case class StackTrace(elems: Seq[String]) { + override def toString: String = elems.mkString + + def html: NodeSeq = { + val withNewLine = elems.foldLeft(NodeSeq.Empty) { (acc, elem) => + if (acc.isEmpty) { + acc :+ Text(elem) + } else { + acc :+
<br />
    :+ Text(elem) + } + } + + withNewLine + } + + def mkString(start: String, sep: String, end: String): String = { + elems.mkString(start, sep, end) + } +} + +case class ThreadStackTrace( + val threadId: Long, + val threadName: String, + val threadState: Thread.State, + val stackTrace: StackTrace, + val blockedByThreadId: Option[Long], + val blockedByLock: String, + val holdingLocks: Seq[String]) diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 1cfd30df4909..646cf25880e3 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -17,9 +17,11 @@ package org.apache.spark.status -import java.lang.{Integer => JInteger, Long => JLong} +import java.lang.{Long => JLong} +import java.util.Date import com.fasterxml.jackson.annotation.JsonIgnore +import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1._ @@ -49,10 +51,10 @@ private[spark] class ApplicationEnvironmentInfoWrapper(val info: ApplicationEnvi private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { @JsonIgnore @KVIndex - private[this] val id: String = info.id + private def id: String = info.id @JsonIgnore @KVIndex("active") - private[this] val active: Boolean = info.isActive + private def active: Boolean = info.isActive @JsonIgnore @KVIndex("host") val host: String = info.hostPort.split(":")(0) @@ -69,52 +71,281 @@ private[spark] class JobDataWrapper( val skippedStages: Set[Int]) { @JsonIgnore @KVIndex - private[this] val id: Int = info.jobId + private def id: Int = info.jobId + @JsonIgnore @KVIndex("completionTime") + private def completionTime: Long = info.completionTime.map(_.getTime).getOrElse(-1L) } private[spark] class StageDataWrapper( val info: StageData, - val jobIds: Set[Int]) { + val jobIds: Set[Int], + @JsonDeserialize(contentAs = classOf[JLong]) + val locality: Map[String, Long]) { @JsonIgnore @KVIndex - def id: Array[Int] = Array(info.stageId, info.attemptId) + private[this] val id: Array[Int] = Array(info.stageId, info.attemptId) @JsonIgnore @KVIndex("stageId") - def stageId: Int = info.stageId + private def stageId: Int = info.stageId + @JsonIgnore @KVIndex("active") + private def active: Boolean = info.status == StageStatus.ACTIVE + + @JsonIgnore @KVIndex("completionTime") + private def completionTime: Long = info.completionTime.map(_.getTime).getOrElse(-1L) +} + +/** + * Tasks have a lot of indices that are used in a few different places. This object keeps logical + * names for these indices, mapped to short strings to save space when using a disk store. 
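 * For example, the executor run time index is stored under the short key "ert" and the scheduler
 * delay under "dly"; callers refer to these only through the constants defined here.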
+ */ +private[spark] object TaskIndexNames { + final val ACCUMULATORS = "acc" + final val ATTEMPT = "att" + final val DESER_CPU_TIME = "dct" + final val DESER_TIME = "des" + final val DISK_SPILL = "dbs" + final val DURATION = "dur" + final val ERROR = "err" + final val EXECUTOR = "exe" + final val HOST = "hst" + final val EXEC_CPU_TIME = "ect" + final val EXEC_RUN_TIME = "ert" + final val GC_TIME = "gc" + final val GETTING_RESULT_TIME = "grt" + final val INPUT_RECORDS = "ir" + final val INPUT_SIZE = "is" + final val LAUNCH_TIME = "lt" + final val LOCALITY = "loc" + final val MEM_SPILL = "mbs" + final val OUTPUT_RECORDS = "or" + final val OUTPUT_SIZE = "os" + final val PEAK_MEM = "pem" + final val RESULT_SIZE = "rs" + final val SCHEDULER_DELAY = "dly" + final val SER_TIME = "rst" + final val SHUFFLE_LOCAL_BLOCKS = "slbl" + final val SHUFFLE_READ_RECORDS = "srr" + final val SHUFFLE_READ_TIME = "srt" + final val SHUFFLE_REMOTE_BLOCKS = "srbl" + final val SHUFFLE_REMOTE_READS = "srby" + final val SHUFFLE_REMOTE_READS_TO_DISK = "srbd" + final val SHUFFLE_TOTAL_READS = "stby" + final val SHUFFLE_TOTAL_BLOCKS = "stbl" + final val SHUFFLE_WRITE_RECORDS = "swr" + final val SHUFFLE_WRITE_SIZE = "sws" + final val SHUFFLE_WRITE_TIME = "swt" + final val STAGE = "stage" + final val STATUS = "sta" + final val TASK_INDEX = "idx" + final val COMPLETION_TIME = "ct" } /** - * The task information is always indexed with the stage ID, since that is how the UI and API - * consume it. That means every indexed value has the stage ID and attempt ID included, aside - * from the actual data being indexed. + * Unlike other data types, the task data wrapper does not keep a reference to the API's TaskData. + * That is to save memory, since for large applications there can be a large number of these + * elements (by default up to 100,000 per stage), and every bit of wasted memory adds up. + * + * It also contains many secondary indices, which are used to sort data efficiently in the UI at the + * expense of storage space (and slower write times). */ private[spark] class TaskDataWrapper( - val info: TaskData, + // Storing this as an object actually saves memory; it's also used as the key in the in-memory + // store, so in that case you'd save the extra copy of the value here. + @KVIndexParam + val taskId: JLong, + @KVIndexParam(value = TaskIndexNames.TASK_INDEX, parent = TaskIndexNames.STAGE) + val index: Int, + @KVIndexParam(value = TaskIndexNames.ATTEMPT, parent = TaskIndexNames.STAGE) + val attempt: Int, + @KVIndexParam(value = TaskIndexNames.LAUNCH_TIME, parent = TaskIndexNames.STAGE) + val launchTime: Long, + val resultFetchStart: Long, + @KVIndexParam(value = TaskIndexNames.DURATION, parent = TaskIndexNames.STAGE) + val duration: Long, + @KVIndexParam(value = TaskIndexNames.EXECUTOR, parent = TaskIndexNames.STAGE) + val executorId: String, + @KVIndexParam(value = TaskIndexNames.HOST, parent = TaskIndexNames.STAGE) + val host: String, + @KVIndexParam(value = TaskIndexNames.STATUS, parent = TaskIndexNames.STAGE) + val status: String, + @KVIndexParam(value = TaskIndexNames.LOCALITY, parent = TaskIndexNames.STAGE) + val taskLocality: String, + val speculative: Boolean, + val accumulatorUpdates: Seq[AccumulableInfo], + val errorMessage: Option[String], + + // The following is an exploded view of a TaskMetrics API object. This saves 5 objects + // (= 80 bytes of Java object overhead) per instance of this wrapper. 
If the first value + // (executorDeserializeTime) is -1L, it means the metrics for this task have not been + // recorded. + @KVIndexParam(value = TaskIndexNames.DESER_TIME, parent = TaskIndexNames.STAGE) + val executorDeserializeTime: Long, + @KVIndexParam(value = TaskIndexNames.DESER_CPU_TIME, parent = TaskIndexNames.STAGE) + val executorDeserializeCpuTime: Long, + @KVIndexParam(value = TaskIndexNames.EXEC_RUN_TIME, parent = TaskIndexNames.STAGE) + val executorRunTime: Long, + @KVIndexParam(value = TaskIndexNames.EXEC_CPU_TIME, parent = TaskIndexNames.STAGE) + val executorCpuTime: Long, + @KVIndexParam(value = TaskIndexNames.RESULT_SIZE, parent = TaskIndexNames.STAGE) + val resultSize: Long, + @KVIndexParam(value = TaskIndexNames.GC_TIME, parent = TaskIndexNames.STAGE) + val jvmGcTime: Long, + @KVIndexParam(value = TaskIndexNames.SER_TIME, parent = TaskIndexNames.STAGE) + val resultSerializationTime: Long, + @KVIndexParam(value = TaskIndexNames.MEM_SPILL, parent = TaskIndexNames.STAGE) + val memoryBytesSpilled: Long, + @KVIndexParam(value = TaskIndexNames.DISK_SPILL, parent = TaskIndexNames.STAGE) + val diskBytesSpilled: Long, + @KVIndexParam(value = TaskIndexNames.PEAK_MEM, parent = TaskIndexNames.STAGE) + val peakExecutionMemory: Long, + @KVIndexParam(value = TaskIndexNames.INPUT_SIZE, parent = TaskIndexNames.STAGE) + val inputBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.INPUT_RECORDS, parent = TaskIndexNames.STAGE) + val inputRecordsRead: Long, + @KVIndexParam(value = TaskIndexNames.OUTPUT_SIZE, parent = TaskIndexNames.STAGE) + val outputBytesWritten: Long, + @KVIndexParam(value = TaskIndexNames.OUTPUT_RECORDS, parent = TaskIndexNames.STAGE) + val outputRecordsWritten: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_BLOCKS, parent = TaskIndexNames.STAGE) + val shuffleRemoteBlocksFetched: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_LOCAL_BLOCKS, parent = TaskIndexNames.STAGE) + val shuffleLocalBlocksFetched: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_READ_TIME, parent = TaskIndexNames.STAGE) + val shuffleFetchWaitTime: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_READS, parent = TaskIndexNames.STAGE) + val shuffleRemoteBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_READS_TO_DISK, + parent = TaskIndexNames.STAGE) + val shuffleRemoteBytesReadToDisk: Long, + val shuffleLocalBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_READ_RECORDS, parent = TaskIndexNames.STAGE) + val shuffleRecordsRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_SIZE, parent = TaskIndexNames.STAGE) + val shuffleBytesWritten: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_TIME, parent = TaskIndexNames.STAGE) + val shuffleWriteTime: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_RECORDS, parent = TaskIndexNames.STAGE) + val shuffleRecordsWritten: Long, + val stageId: Int, val stageAttemptId: Int) { - @JsonIgnore @KVIndex - def id: Long = info.taskId + def hasMetrics: Boolean = executorDeserializeTime >= 0 + + def toApi: TaskData = { + val metrics = if (hasMetrics) { + Some(new TaskMetrics( + executorDeserializeTime, + executorDeserializeCpuTime, + executorRunTime, + executorCpuTime, + resultSize, + jvmGcTime, + resultSerializationTime, + memoryBytesSpilled, + diskBytesSpilled, + peakExecutionMemory, + new InputMetrics( + inputBytesRead, + inputRecordsRead), + new OutputMetrics( + outputBytesWritten, + outputRecordsWritten), + new ShuffleReadMetrics( + shuffleRemoteBlocksFetched, + 
shuffleLocalBlocksFetched, + shuffleFetchWaitTime, + shuffleRemoteBytesRead, + shuffleRemoteBytesReadToDisk, + shuffleLocalBytesRead, + shuffleRecordsRead), + new ShuffleWriteMetrics( + shuffleBytesWritten, + shuffleWriteTime, + shuffleRecordsWritten))) + } else { + None + } - @JsonIgnore @KVIndex("stage") - def stage: Array[Int] = Array(stageId, stageAttemptId) + new TaskData( + taskId, + index, + attempt, + new Date(launchTime), + if (resultFetchStart > 0L) Some(new Date(resultFetchStart)) else None, + if (duration > 0L) Some(duration) else None, + executorId, + host, + status, + taskLocality, + speculative, + accumulatorUpdates, + errorMessage, + metrics) + } + + @JsonIgnore @KVIndex(TaskIndexNames.STAGE) + private def stage: Array[Int] = Array(stageId, stageAttemptId) - @JsonIgnore @KVIndex("runtime") - def runtime: Array[AnyRef] = { - val _runtime = info.taskMetrics.map(_.executorRunTime).getOrElse(-1L) - Array(stageId: JInteger, stageAttemptId: JInteger, _runtime: JLong) + @JsonIgnore @KVIndex(value = TaskIndexNames.SCHEDULER_DELAY, parent = TaskIndexNames.STAGE) + def schedulerDelay: Long = { + if (hasMetrics) { + AppStatusUtils.schedulerDelay(launchTime, resultFetchStart, duration, executorDeserializeTime, + resultSerializationTime, executorRunTime) + } else { + -1L + } } - @JsonIgnore @KVIndex("startTime") - def startTime: Array[AnyRef] = { - Array(stageId: JInteger, stageAttemptId: JInteger, info.launchTime.getTime(): JLong) + @JsonIgnore @KVIndex(value = TaskIndexNames.GETTING_RESULT_TIME, parent = TaskIndexNames.STAGE) + def gettingResultTime: Long = { + if (hasMetrics) { + AppStatusUtils.gettingResultTime(launchTime, resultFetchStart, duration) + } else { + -1L + } } - @JsonIgnore @KVIndex("active") - def active: Boolean = info.duration.isEmpty + /** + * Sorting by accumulators is a little weird, and the previous behavior would generate + * insanely long keys in the index. So this implementation just considers the first + * accumulator and its String representation. 
+ */ + @JsonIgnore @KVIndex(value = TaskIndexNames.ACCUMULATORS, parent = TaskIndexNames.STAGE) + private def accumulators: String = { + if (accumulatorUpdates.nonEmpty) { + val acc = accumulatorUpdates.head + s"${acc.name}:${acc.value}" + } else { + "" + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_READS, parent = TaskIndexNames.STAGE) + private def shuffleTotalReads: Long = { + if (hasMetrics) { + shuffleLocalBytesRead + shuffleRemoteBytesRead + } else { + -1L + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_BLOCKS, parent = TaskIndexNames.STAGE) + private def shuffleTotalBlocks: Long = { + if (hasMetrics) { + shuffleLocalBlocksFetched + shuffleRemoteBlocksFetched + } else { + -1L + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.ERROR, parent = TaskIndexNames.STAGE) + private def error: String = if (errorMessage.isDefined) errorMessage.get else "" + @JsonIgnore @KVIndex(value = TaskIndexNames.COMPLETION_TIME, parent = TaskIndexNames.STAGE) + private def completionTime: Long = launchTime + duration } private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { @@ -134,10 +365,13 @@ private[spark] class ExecutorStageSummaryWrapper( val info: ExecutorStageSummary) { @JsonIgnore @KVIndex - val id: Array[Any] = Array(stageId, stageAttemptId, executorId) + private val _id: Array[Any] = Array(stageId, stageAttemptId, executorId) @JsonIgnore @KVIndex("stage") - private[this] val stage: Array[Int] = Array(stageId, stageAttemptId) + private def stage: Array[Int] = Array(stageId, stageAttemptId) + + @JsonIgnore + def id: Array[Any] = _id } @@ -203,3 +437,53 @@ private[spark] class AppSummary( def id: String = classOf[AppSummary].getName() } + +/** + * A cached view of a specific quantile for one stage attempt's metrics. + */ +private[spark] class CachedQuantile( + val stageId: Int, + val stageAttemptId: Int, + val quantile: String, + val taskCount: Long, + + // The following fields are an exploded view of a single entry for TaskMetricDistributions. 
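  // Keeping one flat row per (stageId, stageAttemptId, quantile) lets taskSummary() rebuild a
  // cached distribution with a single store read per requested quantile instead of re-scanning
  // every task of the stage.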
+ val executorDeserializeTime: Double, + val executorDeserializeCpuTime: Double, + val executorRunTime: Double, + val executorCpuTime: Double, + val resultSize: Double, + val jvmGcTime: Double, + val resultSerializationTime: Double, + val gettingResultTime: Double, + val schedulerDelay: Double, + val peakExecutionMemory: Double, + val memoryBytesSpilled: Double, + val diskBytesSpilled: Double, + + val bytesRead: Double, + val recordsRead: Double, + + val bytesWritten: Double, + val recordsWritten: Double, + + val shuffleReadBytes: Double, + val shuffleRecordsRead: Double, + val shuffleRemoteBlocksFetched: Double, + val shuffleLocalBlocksFetched: Double, + val shuffleFetchWaitTime: Double, + val shuffleRemoteBytesRead: Double, + val shuffleRemoteBytesReadToDisk: Double, + val shuffleTotalBlocksFetched: Double, + + val shuffleWriteBytes: Double, + val shuffleWriteRecords: Double, + val shuffleWriteTime: Double) { + + @KVIndex @JsonIgnore + def id: Array[Any] = Array(stageId, stageAttemptId, quantile) + + @KVIndex("stage") @JsonIgnore + def stage: Array[Int] = Array(stageId, stageAttemptId) + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 2c3da0ee85e0..d4a59c33b974 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -18,7 +18,8 @@ package org.apache.spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.util.concurrent.ConcurrentHashMap + +import com.google.common.cache.{CacheBuilder, CacheLoader} import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi @@ -132,10 +133,17 @@ private[spark] object BlockManagerId { getCachedBlockManagerId(obj) } - val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + /** + * The max cache size is hardcoded to 10000, since the size of a BlockManagerId + * object is about 48B, the total memory cost should be below 1MB which is feasible. 
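   * (Roughly 10000 * 48B = 480KB for the cached ids themselves, plus the Guava cache's own
   * bookkeeping overhead.)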
+ */ + val blockManagerIdCache = CacheBuilder.newBuilder() + .maximumSize(10000) + .build(new CacheLoader[BlockManagerId, BlockManagerId]() { + override def load(id: BlockManagerId) = id + }) def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { - blockManagerIdCache.putIfAbsent(id, id) blockManagerIdCache.get(id) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 89a6a71a589a..8e8f7d197c9e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -164,7 +164,8 @@ class BlockManagerMasterEndpoint( val futures = blockManagerInfo.values.map { bm => bm.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove RDD $rddId", e) + logWarning(s"Error trying to remove RDD $rddId from block manager ${bm.blockManagerId}", + e) 0 // zero blocks were removed } }.toSeq @@ -192,11 +193,16 @@ class BlockManagerMasterEndpoint( val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } - Future.sequence( - requiredBlockManagers.map { bm => - bm.slaveEndpoint.ask[Int](removeMsg) - }.toSeq - ) + val futures = requiredBlockManagers.map { bm => + bm.slaveEndpoint.ask[Int](removeMsg).recover { + case e: IOException => + logWarning(s"Error trying to remove broadcast $broadcastId from block manager " + + s"${bm.blockManagerId}", e) + 0 // zero blocks were removed + } + }.toSeq + + Future.sequence(futures) } private def removeBlockManager(blockManagerId: BlockManagerId) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index 353eac60df17..0bacc34cdfd9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -54,10 +54,9 @@ trait BlockReplicationPolicy { } object BlockReplicationUtils { - // scalastyle:off line.size.limit /** * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage. Please see + * minimizing space usage. Please see * here. * * @param n total number of indices @@ -65,7 +64,6 @@ object BlockReplicationUtils { * @param r random number generator * @return list of m random unique indices */ - // scalastyle:on line.size.limit private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) {case (set, i) => val t = r.nextInt(i) + 1 diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 98b5a735a452..b31862323a89 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -48,7 +48,9 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in - * order to throttle the memory usage. + * order to throttle the memory usage. 
Note that zero-sized blocks are + * already excluded, which happened in + * [[MapOutputTracker.convertMapStatuses]]. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. @@ -62,7 +64,7 @@ final class ShuffleBlockFetcherIterator( context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, @@ -74,8 +76,8 @@ final class ShuffleBlockFetcherIterator( import ShuffleBlockFetcherIterator._ /** - * Total number of blocks to fetch. This can be smaller than the total number of blocks - * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. + * Total number of blocks to fetch. This should be equal to the total number of blocks + * in [[blocksByAddress]] because we already filter out zero-sized blocks in [[blocksByAddress]]. * * This should equal localBlocks.size + remoteBlocks.size. */ @@ -90,7 +92,7 @@ final class ShuffleBlockFetcherIterator( private[this] val startTime = System.currentTimeMillis /** Local blocks to fetch, excluding zero-sized blocks. */ - private[this] val localBlocks = new ArrayBuffer[BlockId]() + private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]() /** Remote blocks to fetch, excluding zero-sized blocks. */ private[this] val remoteBlocks = new HashSet[BlockId]() @@ -267,13 +269,16 @@ final class ShuffleBlockFetcherIterator( // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] - // Tracks total number of blocks (including zero sized blocks) - var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size if (address.executorId == blockManager.blockManagerId.executorId) { - // Filter out zero-sized blocks - localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + blockInfos.find(_._2 <= 0) match { + case Some((blockId, size)) if size < 0 => + throw new BlockException(blockId, "Negative block size " + size) + case Some((blockId, size)) if size == 0 => + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + case None => // do nothing. 
+ } + localBlocks ++= blockInfos.map(_._1) numBlocksToFetch += localBlocks.size } else { val iterator = blockInfos.iterator @@ -281,14 +286,15 @@ final class ShuffleBlockFetcherIterator( var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { + if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } else if (size == 0) { + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + } else { curBlocks += ((blockId, size)) remoteBlocks += blockId numBlocksToFetch += 1 curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= targetRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { @@ -306,7 +312,8 @@ final class ShuffleBlockFetcherIterator( } } } - logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") + logInfo(s"Getting $numBlocksToFetch non-empty blocks including ${localBlocks.size}" + + s" local blocks and ${remoteBlocks.size} remote blocks") remoteRequests } @@ -316,6 +323,7 @@ final class ShuffleBlockFetcherIterator( * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchLocalBlocks() { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") val iter = localBlocks.iterator while (iter.hasNext) { val blockId = iter.next() @@ -324,7 +332,8 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, + buf.size(), buf, false)) } catch { case e: Exception => // If we see an exception, stop immediately. @@ -397,12 +406,33 @@ final class ShuffleBlockFetcherIterator( } shuffleMetrics.incRemoteBlocksFetched(1) } - bytesInFlight -= size + if (!localBlocks.contains(blockId)) { + bytesInFlight -= size + } if (isNetworkReqDone) { reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) } + if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + throwFetchFailedException(blockId, address, new IOException(msg)) + } + val in = try { buf.createInputStream() } catch { @@ -583,8 +613,8 @@ object ShuffleBlockFetcherIterator { * Result of a fetch from a remote block successfully. * @param blockId block id * @param address BlockManager that the block was fetched from. 
- * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. + * @param size estimated size of the block. Note that this is NOT the exact bytes. + * Size of remote block is used to calculate bytesInFlight. * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index e9694fdbca2d..adc406bb1c44 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -24,19 +24,15 @@ import scala.collection.mutable import sun.nio.ch.DirectBuffer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging /** - * :: DeveloperApi :: * Storage information for each BlockManager. * * This class assumes BlockId and BlockStatus are immutable, such that the consumers of this * class cannot mutate the source of the information. Accesses are not thread-safe. */ -@DeveloperApi -@deprecated("This class may be removed or made private in a future release.", "2.2.0") -class StorageStatus( +private[spark] class StorageStatus( val blockManagerId: BlockManagerId, val maxMemory: Long, val maxOnHeapMem: Option[Long], @@ -44,9 +40,6 @@ class StorageStatus( /** * Internal representation of the blocks stored in this block manager. - * - * We store RDD blocks and non-RDD blocks separately to allow quick retrievals of RDD blocks. - * These collections should only be mutated through the add/update/removeBlock methods. */ private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]] private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus] @@ -87,9 +80,6 @@ class StorageStatus( */ def rddBlocks: Map[BlockId, BlockStatus] = _rddBlocks.flatMap { case (_, blocks) => blocks } - /** Return the blocks that belong to the given RDD stored in this block manager. */ - def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = _rddBlocks.getOrElse(rddId, Map.empty) - /** Add the given block to this storage status. If it already exists, overwrite it. */ private[spark] def addBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { updateStorageInfo(blockId, blockStatus) @@ -101,46 +91,6 @@ class StorageStatus( } } - /** Update the given block in this storage status. If it doesn't already exist, add it. */ - private[spark] def updateBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { - addBlock(blockId, blockStatus) - } - - /** Remove the given block from this storage status. */ - private[spark] def removeBlock(blockId: BlockId): Option[BlockStatus] = { - updateStorageInfo(blockId, BlockStatus.empty) - blockId match { - case RDDBlockId(rddId, _) => - // Actually remove the block, if it exists - if (_rddBlocks.contains(rddId)) { - val removed = _rddBlocks(rddId).remove(blockId) - // If the given RDD has no more blocks left, remove the RDD - if (_rddBlocks(rddId).isEmpty) { - _rddBlocks.remove(rddId) - } - removed - } else { - None - } - case _ => - _nonRddBlocks.remove(blockId) - } - } - - /** - * Return whether the given block is stored in this block manager in O(1) time. - * - * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. 
- */ - def containsBlock(blockId: BlockId): Boolean = { - blockId match { - case RDDBlockId(rddId, _) => - _rddBlocks.get(rddId).exists(_.contains(blockId)) - case _ => - _nonRddBlocks.contains(blockId) - } - } - /** * Return the given block stored in this block manager in O(1) time. * @@ -155,37 +105,12 @@ class StorageStatus( } } - /** - * Return the number of blocks stored in this block manager in O(RDDs) time. - * - * @note This is much faster than `this.blocks.size`, which is O(blocks) time. - */ - def numBlocks: Int = _nonRddBlocks.size + numRddBlocks - - /** - * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * - * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. - */ - def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum - - /** - * Return the number of blocks that belong to the given RDD in O(1) time. - * - * @note This is much faster than `this.rddBlocksById(rddId).size`, which is - * O(blocks in this RDD) time. - */ - def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) - /** Return the max memory can be used by this block manager. */ def maxMem: Long = maxMemory /** Return the memory remaining in this block manager. */ def memRemaining: Long = maxMem - memUsed - /** Return the memory used by caching RDDs */ - def cacheSize: Long = onHeapCacheSize.getOrElse(0L) + offHeapCacheSize.getOrElse(0L) - /** Return the memory used by this block manager. */ def memUsed: Long = onHeapMemUsed.getOrElse(0L) + offHeapMemUsed.getOrElse(0L) @@ -220,15 +145,9 @@ class StorageStatus( /** Return the disk space used by this block manager. */ def diskUsed: Long = _nonRddStorageInfo.diskUsage + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum - /** Return the memory used by the given RDD in this block manager in O(1) time. */ - def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.memoryUsage).getOrElse(0L) - /** Return the disk space used by the given RDD in this block manager in O(1) time. */ def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.diskUsage).getOrElse(0L) - /** Return the storage level, if any, used by the given RDD in this block manager. */ - def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_.level) - /** * Update the relevant storage info, taking into account any existing status for this block. */ @@ -295,40 +214,4 @@ private[spark] object StorageUtils extends Logging { cleaner.clean() } } - - /** - * Update the given list of RDDInfo with the given list of storage statuses. - * This method overwrites the old values stored in the RDDInfo's. - */ - def updateRddInfo(rddInfos: Seq[RDDInfo], statuses: Seq[StorageStatus]): Unit = { - rddInfos.foreach { rddInfo => - val rddId = rddInfo.id - // Assume all blocks belonging to the same RDD have the same storage level - val storageLevel = statuses - .flatMap(_.rddStorageLevel(rddId)).headOption.getOrElse(StorageLevel.NONE) - val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum - val memSize = statuses.map(_.memUsedByRdd(rddId)).sum - val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum - - rddInfo.storageLevel = storageLevel - rddInfo.numCachedPartitions = numCachedPartitions - rddInfo.memSize = memSize - rddInfo.diskSize = diskSize - } - } - - /** - * Return a mapping from block ID to its locations for each block that belongs to the given RDD. 
- */ - def getRddBlockLocations(rddId: Int, statuses: Seq[StorageStatus]): Map[BlockId, Seq[String]] = { - val blockLocations = new mutable.HashMap[BlockId, mutable.ListBuffer[String]] - statuses.foreach { status => - status.rddBlocksById(rddId).foreach { case (bid, _) => - val location = status.blockManagerId.hostPort - blockLocations.getOrElseUpdate(bid, mutable.ListBuffer.empty) += location - } - } - blockLocations - } - } diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 17f7a69ad6ba..4cc5bcb7f9ba 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -162,7 +162,7 @@ private[spark] class MemoryStore( } /** - * Attempt to put the given block in memory store as values. + * Attempt to put the given block in memory store as values or bytes. * * It's possible that the iterator is too large to materialize and store in memory. To avoid * OOM exceptions, this method will gradually unroll the iterator while periodically checking @@ -170,18 +170,24 @@ private[spark] class MemoryStore( * temporary unroll memory used during the materialization is "transferred" to storage memory, * so we won't acquire more memory than is actually needed to store the block. * - * @return in case of success, the estimated size of the stored data. In case of failure, return - * an iterator containing the values of the block. The returned iterator will be backed - * by the combination of the partially-unrolled block and the remaining elements of the - * original input iterator. The caller must either fully consume this iterator or call - * `close()` on it in order to free the storage memory consumed by the partially-unrolled - * block. + * @param blockId The block id. + * @param values The values which need to be stored. + * @param classTag the [[ClassTag]] for the block. + * @param memoryMode The memory mode (ON_HEAP or OFF_HEAP) in which the values will be stored. + * @param valuesHolder A holder that supports storing records of values into the memory store as + * values or bytes. + * @return if the block is stored successfully, the stored data size. Otherwise, the memory + * reserved for unrolling the block (a store can fail for two reasons: either the block was + * only partially unrolled, or it was fully unrolled but its actual size exceeded the + * reserved memory and extra memory could not be requested).
*/ - private[storage] def putIteratorAsValues[T]( + private def putIterator[T]( blockId: BlockId, values: Iterator[T], - classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { - + classTag: ClassTag[T], + memoryMode: MemoryMode, + valuesHolder: ValuesHolder[T]): Either[Long, Long] = { require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") // Number of elements unrolled so far @@ -198,12 +204,10 @@ private[spark] class MemoryStore( val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Keep track of unroll memory used by this particular block / putIterator() operation var unrollMemoryUsedByThisBlock = 0L - // Underlying vector for unrolling the block - var vector = new SizeTrackingVector[T]()(classTag) // Request enough memory to begin unrolling keepUnrolling = - reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, MemoryMode.ON_HEAP) + reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -214,14 +218,14 @@ private[spark] class MemoryStore( // Unroll this block safely, checking whether we have exceeded our threshold periodically while (values.hasNext && keepUnrolling) { - vector += values.next() + valuesHolder.storeValue(values.next()) if (elementsUnrolled % memoryCheckPeriod == 0) { + val currentSize = valuesHolder.estimatedSize() // If our vector's size has exceeded the threshold, request more memory - val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong keepUnrolling = - reserveUnrollMemoryForThisTask(blockId, amountToRequest, MemoryMode.ON_HEAP) + reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) if (keepUnrolling) { unrollMemoryUsedByThisBlock += amountToRequest } @@ -232,78 +236,86 @@ private[spark] class MemoryStore( elementsUnrolled += 1 } + // Make sure that we have enough memory to store the block. By this point, it is possible that + // the block's actual memory usage has exceeded the unroll memory by a small amount, so we + // perform one final call to attempt to allocate additional memory if necessary. if (keepUnrolling) { - // We successfully unrolled the entirety of this block - val arrayValues = vector.toArray - vector = null - val entry = - new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag) - val size = entry.size - def transferUnrollToStorage(amount: Long): Unit = { + val entryBuilder = valuesHolder.getBuilder() + val size = entryBuilder.preciseSize + if (size > unrollMemoryUsedByThisBlock) { + val amountToRequest = size - unrollMemoryUsedByThisBlock + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) + if (keepUnrolling) { + unrollMemoryUsedByThisBlock += amountToRequest + } + } + + if (keepUnrolling) { + val entry = entryBuilder.build() // Synchronize so that transfer is atomic memoryManager.synchronized { - releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, amount) - val success = memoryManager.acquireStorageMemory(blockId, amount, MemoryMode.ON_HEAP) + releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock) + val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode) assert(success, "transferring unroll memory to storage memory failed") } - } - // Acquire storage memory if necessary to store this block in memory. 
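A rough sketch of the reservation arithmetic driving the unroll loop above; the 1 MB starting reservation and the 1.5 growth factor are illustrative placeholders for the config-driven values, and reserve stands in for reserveUnrollMemoryForThisTask:

object UnrollMath {
  val memoryGrowthFactor = 1.5                  // placeholder for UNROLL_MEMORY_GROWTH_FACTOR
  var memoryThreshold: Long = 1L * 1024 * 1024  // placeholder for the initial unroll reservation

  // Invoked every memoryCheckPeriod elements with the holder's estimated size.
  def maybeReserveMore(currentSize: Long, reserve: Long => Boolean): Boolean = {
    if (currentSize < memoryThreshold) {
      true  // still within the current reservation, keep unrolling
    } else {
      // Grow the reservation so that it reaches currentSize * memoryGrowthFactor.
      val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
      val granted = reserve(amountToRequest)
      if (granted) memoryThreshold += amountToRequest
      granted
    }
  }
}

With these placeholder numbers, the first time the estimate crosses 1 MB the task asks for roughly another 0.5 MB, so the reservation stays ahead of the measured size and the final top-up against preciseSize after the loop stays small.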
- val enoughStorageMemory = { - if (unrollMemoryUsedByThisBlock <= size) { - val acquiredExtra = - memoryManager.acquireStorageMemory( - blockId, size - unrollMemoryUsedByThisBlock, MemoryMode.ON_HEAP) - if (acquiredExtra) { - transferUnrollToStorage(unrollMemoryUsedByThisBlock) - } - acquiredExtra - } else { // unrollMemoryUsedByThisBlock > size - // If this task attempt already owns more unroll memory than is necessary to store the - // block, then release the extra memory that will not be used. - val excessUnrollMemory = unrollMemoryUsedByThisBlock - size - releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, excessUnrollMemory) - transferUnrollToStorage(size) - true - } - } - if (enoughStorageMemory) { + entries.synchronized { entries.put(blockId, entry) } - logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) - Right(size) + + logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(blockId, + Utils.bytesToString(entry.size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) + Right(entry.size) } else { - assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock, - "released too much unroll memory") + // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, entryBuilder.preciseSize) + Left(unrollMemoryUsedByThisBlock) + } + } else { + // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, valuesHolder.estimatedSize()) + Left(unrollMemoryUsedByThisBlock) + } + } + + /** + * Attempt to put the given block in memory store as values. + * + * @return in case of success, the estimated size of the stored data. In case of failure, return + * an iterator containing the values of the block. The returned iterator will be backed + * by the combination of the partially-unrolled block and the remaining elements of the + * original input iterator. The caller must either fully consume this iterator or call + * `close()` on it in order to free the storage memory consumed by the partially-unrolled + * block. + */ + private[storage] def putIteratorAsValues[T]( + blockId: BlockId, + values: Iterator[T], + classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { + + val valuesHolder = new DeserializedValuesHolder[T](classTag) + + putIterator(blockId, values, classTag, MemoryMode.ON_HEAP, valuesHolder) match { + case Right(storedSize) => Right(storedSize) + case Left(unrollMemoryUsedByThisBlock) => + val unrolledIterator = if (valuesHolder.vector != null) { + valuesHolder.vector.iterator + } else { + valuesHolder.arrayValues.toIterator + } + Left(new PartiallyUnrolledIterator( this, MemoryMode.ON_HEAP, unrollMemoryUsedByThisBlock, - unrolled = arrayValues.toIterator, - rest = Iterator.empty)) - } - } else { - // We ran out of space while unrolling the values for this block - logUnrollFailureMessage(blockId, vector.estimateSize()) - Left(new PartiallyUnrolledIterator( - this, - MemoryMode.ON_HEAP, - unrollMemoryUsedByThisBlock, - unrolled = vector.iterator, - rest = values)) + unrolled = unrolledIterator, + rest = values)) } } /** * Attempt to put the given block in memory store as bytes. * - * It's possible that the iterator is too large to materialize and store in memory. To avoid - * OOM exceptions, this method will gradually unroll the iterator while periodically checking - * whether there is enough free memory. 
If the block is successfully materialized, then the - * temporary unroll memory used during the materialization is "transferred" to storage memory, - * so we won't acquire more memory than is actually needed to store the block. - * * @return in case of success, the estimated size of the stored data. In case of failure, * return a handle which allows the caller to either finish the serialization by * spilling to disk or to deserialize the partially-serialized block and reconstruct @@ -319,25 +331,8 @@ private[spark] class MemoryStore( require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") - val allocator = memoryMode match { - case MemoryMode.ON_HEAP => ByteBuffer.allocate _ - case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ - } - - // Whether there is still enough memory for us to continue unrolling this block - var keepUnrolling = true - // Number of elements unrolled so far - var elementsUnrolled = 0L - // How often to check whether we need to request more memory - val memoryCheckPeriod = conf.get(UNROLL_MEMORY_CHECK_PERIOD) - // Memory to request as a multiple of current bbos size - val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Initial per-task memory to request for unrolling blocks (bytes). val initialMemoryThreshold = unrollMemoryThreshold - // Keep track of unroll memory used by this particular block / putIterator() operation - var unrollMemoryUsedByThisBlock = 0L - // Underlying buffer for unrolling the block - val redirectableStream = new RedirectableOutputStream val chunkSize = if (initialMemoryThreshold > Int.MaxValue) { logWarning(s"Initial memory threshold of ${Utils.bytesToString(initialMemoryThreshold)} " + s"is too large to be set as chunk size. Chunk size has been capped to " + @@ -346,85 +341,22 @@ private[spark] class MemoryStore( } else { initialMemoryThreshold.toInt } - val bbos = new ChunkedByteBufferOutputStream(chunkSize, allocator) - redirectableStream.setOutputStream(bbos) - val serializationStream: SerializationStream = { - val autoPick = !blockId.isInstanceOf[StreamBlockId] - val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() - ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) - } - // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode) + val valuesHolder = new SerializedValuesHolder[T](blockId, chunkSize, classTag, + memoryMode, serializerManager) - if (!keepUnrolling) { - logWarning(s"Failed to reserve initial memory threshold of " + - s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") - } else { - unrollMemoryUsedByThisBlock += initialMemoryThreshold - } - - def reserveAdditionalMemoryIfNecessary(): Unit = { - if (bbos.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = (bbos.size * memoryGrowthFactor - unrollMemoryUsedByThisBlock).toLong - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) - if (keepUnrolling) { - unrollMemoryUsedByThisBlock += amountToRequest - } - } - } - - // Unroll this block safely, checking whether we have exceeded our threshold - while (values.hasNext && keepUnrolling) { - serializationStream.writeObject(values.next())(classTag) - elementsUnrolled += 1 - if (elementsUnrolled % memoryCheckPeriod == 0) { - reserveAdditionalMemoryIfNecessary() - } - } - - // Make sure that we have enough memory to store the block. 
By this point, it is possible that - // the block's actual memory usage has exceeded the unroll memory by a small amount, so we - // perform one final call to attempt to allocate additional memory if necessary. - if (keepUnrolling) { - serializationStream.close() - if (bbos.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) - if (keepUnrolling) { - unrollMemoryUsedByThisBlock += amountToRequest - } - } - } - - if (keepUnrolling) { - val entry = SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag) - // Synchronize so that transfer is atomic - memoryManager.synchronized { - releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock) - val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode) - assert(success, "transferring unroll memory to storage memory failed") - } - entries.synchronized { - entries.put(blockId, entry) - } - logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(entry.size), - Utils.bytesToString(maxMemory - blocksMemoryUsed))) - Right(entry.size) - } else { - // We ran out of space while unrolling the values for this block - logUnrollFailureMessage(blockId, bbos.size) - Left( - new PartiallySerializedBlock( + putIterator(blockId, values, classTag, memoryMode, valuesHolder) match { + case Right(storedSize) => Right(storedSize) + case Left(unrollMemoryUsedByThisBlock) => + Left(new PartiallySerializedBlock( this, serializerManager, blockId, - serializationStream, - redirectableStream, + valuesHolder.serializationStream, + valuesHolder.redirectableStream, unrollMemoryUsedByThisBlock, memoryMode, - bbos, + valuesHolder.bbos, values, classTag)) } @@ -702,6 +634,94 @@ private[spark] class MemoryStore( } } +private trait MemoryEntryBuilder[T] { + def preciseSize: Long + def build(): MemoryEntry[T] +} + +private trait ValuesHolder[T] { + def storeValue(value: T): Unit + def estimatedSize(): Long + + /** + * Note: After this method is called, the ValuesHolder is invalid, we can't store data and + * get estimate size again. + * @return a MemoryEntryBuilder which is used to build a memory entry and get the stored data + * size. + */ + def getBuilder(): MemoryEntryBuilder[T] +} + +/** + * A holder for storing the deserialized values. + */ +private class DeserializedValuesHolder[T] (classTag: ClassTag[T]) extends ValuesHolder[T] { + // Underlying vector for unrolling the block + var vector = new SizeTrackingVector[T]()(classTag) + var arrayValues: Array[T] = null + + override def storeValue(value: T): Unit = { + vector += value + } + + override def estimatedSize(): Long = { + vector.estimateSize() + } + + override def getBuilder(): MemoryEntryBuilder[T] = new MemoryEntryBuilder[T] { + // We successfully unrolled the entirety of this block + arrayValues = vector.toArray + vector = null + + override val preciseSize: Long = SizeEstimator.estimate(arrayValues) + + override def build(): MemoryEntry[T] = + DeserializedMemoryEntry[T](arrayValues, preciseSize, classTag) + } +} + +/** + * A holder for storing the serialized values. 
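The ValuesHolder / MemoryEntryBuilder handshake defined above (store values, report a cheap running estimate, then build exactly once) can be illustrated with a deliberately naive holder. The traits below are stand-ins that mirror the shape of the private ones added in this patch, and the flat 16-bytes-per-element estimate is purely hypothetical:

// Stand-ins mirroring the shape of the private traits added in this patch.
trait MemoryEntry[T] { def size: Long }

trait MemoryEntryBuilder[T] {
  def preciseSize: Long
  def build(): MemoryEntry[T]
}

trait ValuesHolder[T] {
  def storeValue(value: T): Unit
  def estimatedSize(): Long
  def getBuilder(): MemoryEntryBuilder[T]  // invalidates the holder
}

// A naive holder that buffers values and charges a flat 16 bytes per element,
// standing in for the SizeTrackingVector / ChunkedByteBufferOutputStream estimates.
class CountingValuesHolder[T] extends ValuesHolder[T] {
  private val buffer = scala.collection.mutable.ArrayBuffer.empty[T]

  override def storeValue(value: T): Unit = buffer += value

  override def estimatedSize(): Long = buffer.size.toLong * 16

  override def getBuilder(): MemoryEntryBuilder[T] = new MemoryEntryBuilder[T] {
    // Snapshot the final size; per the contract above, the holder is not reused afterwards.
    override val preciseSize: Long = buffer.size.toLong * 16
    override def build(): MemoryEntry[T] = new MemoryEntry[T] {
      override val size: Long = preciseSize
    }
  }
}

Keeping getBuilder() separate from storeValue() is what lets putIterator stay agnostic about whether the block ends up as a deserialized array or as serialized bytes.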
+ */ +private class SerializedValuesHolder[T]( + blockId: BlockId, + chunkSize: Int, + classTag: ClassTag[T], + memoryMode: MemoryMode, + serializerManager: SerializerManager) extends ValuesHolder[T] { + val allocator = memoryMode match { + case MemoryMode.ON_HEAP => ByteBuffer.allocate _ + case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ + } + + val redirectableStream = new RedirectableOutputStream + val bbos = new ChunkedByteBufferOutputStream(chunkSize, allocator) + redirectableStream.setOutputStream(bbos) + val serializationStream: SerializationStream = { + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() + ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) + } + + override def storeValue(value: T): Unit = { + serializationStream.writeObject(value)(classTag) + } + + override def estimatedSize(): Long = { + bbos.size + } + + override def getBuilder(): MemoryEntryBuilder[T] = new MemoryEntryBuilder[T] { + // We successfully unrolled the entirety of this block + serializationStream.close() + + override def preciseSize(): Long = bbos.size + + override def build(): MemoryEntry[T] = + SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag) + } +} + /** * The result of a failed [[MemoryStore.putIteratorAsValues()]] call. * diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 0adeb4058b6e..d6a025a6f12d 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -343,12 +343,14 @@ private[spark] object JettyUtils extends Logging { -1, connectionFactories: _*) connector.setPort(port) - connector.start() + connector.setHost(hostName) + connector.setReuseAddress(!Utils.isWindows) // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8)) - connector.setHost(hostName) + + connector.start() // The number of selectors always equals to the number of acceptors minThreads += connector.getAcceptors * 2 diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index ba798df13c95..02cf19e00ecd 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -224,6 +224,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes} {if (showVisualization) vizHeaderNodes else Seq.empty} {if (useDataTables) dataTablesHeaderNodes else Seq.empty} + {appName} - {title} @@ -265,6 +266,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes} {if (useDataTables) dataTablesHeaderNodes else Seq.empty} + {title} diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index 43adab7a35d6..902eb92b854f 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -48,10 +48,50 @@ private[ui] class EnvironmentPage( classPathHeaders, classPathRow, appEnv.classpathEntries, fixedWidth = true) val content = -

    Runtime Information

    {runtimeInformationTable} -

    Spark Properties

    {sparkPropertiesTable} -

    System Properties

    {systemPropertiesTable} -

    Classpath Entries

    {classpathEntriesTable} + +

    + + Runtime Information +

    +
    +
    + {runtimeInformationTable} +
    + +

    + + Spark Properties +

    +
    +
    + {sparkPropertiesTable} +
    + +

    + + System Properties +

    +
    +
    + {systemPropertiesTable} +
    + +

    + + Classpath Entries +

    +
    +
    + {classpathEntriesTable} +
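The Environment page above, and the jobs and stages pages further below, repeat the same pattern of a clickable header followed by a table body that can be collapsed. A hedged sketch of a generic helper in that spirit, where the CSS class names and the collapseTable() JavaScript hook are placeholders rather than the patch's exact markup:

import scala.xml.NodeSeq

// Hypothetical helper; class names and the collapseTable() hook are placeholders.
def collapsibleSection(id: String, title: String, body: NodeSeq): NodeSeq = {
  val header =
    <span class={s"collapse-$id collapse-table"}
          onClick={s"collapseTable('collapse-$id', '$id')"}>
      <h4>
        <span class="collapse-table-arrow arrow-open"></span>
        <a>{title}</a>
      </h4>
    </span>
  val table =
    <div class={s"$id collapsible-table"}>
      {body}
    </div>
  header ++ table
}

A page could then assemble its content as, say, collapsibleSection("runtime-information", "Runtime Information", runtimeInformationTable) and so on, instead of repeating the markup for every table.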
    UIUtils.headerSparkPage("Environment", content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index f4686ea3cf91..9bb026c60565 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,7 +17,6 @@ package org.apache.spark.ui.exec -import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} @@ -41,17 +40,7 @@ private[ui] class ExecutorThreadDumpPage( val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) val content = maybeThreadDump.map { threadDump => - val dumpRows = threadDump.sortWith { - case (threadTrace1, threadTrace2) => - val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 - val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 - if (v1 == v2) { - threadTrace1.threadName.toLowerCase(Locale.ROOT) < - threadTrace2.threadName.toLowerCase(Locale.ROOT) - } else { - v1 > v2 - } - }.map { thread => + val dumpRows = threadDump.map { thread => val threadId = thread.threadId val blockedBy = thread.blockedByThreadId match { case Some(_) => @@ -71,7 +60,7 @@ private[ui] class ExecutorThreadDumpPage( {thread.threadName} {thread.threadState} {blockedBy}{heldLocks} - {thread.stackTrace} + {thread.stackTrace.html} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 37e3b3b304a6..2b0f4acbac72 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -36,6 +36,9 @@ import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished jobs */ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends WebUIPage("") { + + import ApiHelper._ + private val JOBS_LEGEND =
    val jobId = job.jobId val status = job.status - val displayJobDescription = - if (job.description.isEmpty) { - job.name - } else { - UIUtils.makeDescription(job.description.get, "", plainText = true).text - } + val (_, lastStageDescription) = lastStageNameAndDescription(store, job) + val jobDescription = UIUtils.makeDescription(lastStageDescription, "", plainText = true).text + val submissionTime = job.submissionTime.get.getTime() val completionTime = job.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) val classNameByStatus = status match { @@ -82,7 +82,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We // The timeline library treats contents as HTML, so we have to escape them. We need to add // extra layers of escaping in order to embed this in a Javascript string literal. - val escapedDesc = Utility.escape(displayJobDescription) + val escapedDesc = Utility.escape(jobDescription) val jsEscapedDesc = StringEscapeUtils.escapeEcmaScript(escapedDesc) val jobEventJsonAsStr = s""" @@ -365,16 +365,43 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We store.executorList(false), startTime) if (shouldShowActiveJobs) { - content ++=

    Active Jobs ({activeJobs.size})

    ++ - activeJobsTable + content ++= + +

    + + Active Jobs ({activeJobs.size}) +

    +
    ++ +
    + {activeJobsTable} +
    } if (shouldShowCompletedJobs) { - content ++=

    Completed Jobs ({completedJobNumStr})

    ++ - completedJobsTable + content ++= + +

    + + Completed Jobs ({completedJobNumStr}) +

    +
    ++ +
    + {completedJobsTable} +
    } if (shouldShowFailedJobs) { - content ++=

    Failed Jobs ({failedJobs.size})

    ++ - failedJobsTable + content ++= + +

    + + Failed Jobs ({failedJobs.size}) +

    +
    ++ +
    + {failedJobsTable} +
    } val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + @@ -405,6 +432,8 @@ private[ui] class JobDataSource( sortColumn: String, desc: Boolean) extends PagedDataSource[JobTableRowData](pageSize) { + import ApiHelper._ + // Convert JobUIData to JobTableRowData which contains the final contents to show in the table // so that we can avoid creating duplicate contents during sorting the data private val data = jobs.map(jobRow).sorted(ordering(sortColumn, desc)) @@ -429,15 +458,16 @@ private[ui] class JobDataSource( val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val submissionTime = jobData.submissionTime val formattedSubmissionTime = submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val jobDescription = UIUtils.makeDescription(jobData.description.getOrElse(""), - basePath, plainText = false) + val (lastStageName, lastStageDescription) = lastStageNameAndDescription(store, jobData) + + val jobDescription = UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) new JobTableRowData( jobData, - jobData.name, - jobData.description.getOrElse(jobData.name), + lastStageName, + lastStageDescription, duration.getOrElse(-1), formattedDuration, submissionTime.map(_.getTime()).getOrElse(-1L), diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index b1e343451e28..4658aa1cea3f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -19,42 +19,20 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, NodeSeq} +import scala.xml.{Attribute, Elem, Node, NodeSeq, Null, Text} import org.apache.spark.scheduler.Schedulable -import org.apache.spark.status.PoolData -import org.apache.spark.status.api.v1._ +import org.apache.spark.status.{AppSummary, PoolData} +import org.apache.spark.status.api.v1.{StageData, StageStatus} import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { private val sc = parent.sc + private val subPath = "stages" private def isFairScheduler = parent.isFairScheduler def render(request: HttpServletRequest): Seq[Node] = { - val allStages = parent.store.stageList(null) - - val activeStages = allStages.filter(_.status == StageStatus.ACTIVE) - val pendingStages = allStages.filter(_.status == StageStatus.PENDING) - val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) - val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse - - val numFailedStages = failedStages.size - val subPath = "stages" - - val activeStagesTable = - new StageTableBase(parent.store, request, activeStages, "active", "activeStage", - parent.basePath, subPath, parent.isFairScheduler, parent.killEnabled, false) - val pendingStagesTable = - new StageTableBase(parent.store, request, pendingStages, "pending", "pendingStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val completedStagesTable = - new StageTableBase(parent.store, request, completedStages, "completed", "completedStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val failedStagesTable = - new StageTableBase(parent.store, request, 
failedStages, "failed", "failedStage", - parent.basePath, subPath, parent.isFairScheduler, false, true) - // For now, pool information is only accessible in live UIs val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]).map { pool => val uiPool = parent.store.asOption(parent.store.pool(pool.name)).getOrElse( @@ -63,80 +41,121 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { }.toMap val poolTable = new PoolTable(pools, parent) - val shouldShowActiveStages = activeStages.nonEmpty - val shouldShowPendingStages = pendingStages.nonEmpty - val shouldShowCompletedStages = completedStages.nonEmpty - val shouldShowFailedStages = failedStages.nonEmpty + val allStatuses = Seq(StageStatus.ACTIVE, StageStatus.PENDING, StageStatus.COMPLETE, + StageStatus.SKIPPED, StageStatus.FAILED) + val allStages = parent.store.stageList(null) val appSummary = parent.store.appSummary() - val completedStageNumStr = if (appSummary.numCompletedStages == completedStages.size) { - s"${appSummary.numCompletedStages}" - } else { - s"${appSummary.numCompletedStages}, only showing ${completedStages.size}" - } + + val (summaries, tables) = allStatuses.map( + summaryAndTableForStatus(allStages, appSummary, _, request)).unzip val summary: NodeSeq =
      - { - if (shouldShowActiveStages) { -
    • - Active Stages: - {activeStages.size} -
    • - } - } - { - if (shouldShowPendingStages) { -
    • - Pending Stages: - {pendingStages.size} -
    • - } - } - { - if (shouldShowCompletedStages) { -
    • - Completed Stages: - {completedStageNumStr} -
    • - } - } - { - if (shouldShowFailedStages) { -
    • - Failed Stages: - {numFailedStages} -
    • - } - } + {summaries.flatten}
    - var content = summary ++ - { - if (sc.isDefined && isFairScheduler) { -

    Fair Scheduler Pools ({pools.size})

    ++ poolTable.toNodeSeq - } else { - Seq.empty[Node] - } + val poolsDescription = if (sc.isDefined && isFairScheduler) { + +

    + + Fair Scheduler Pools ({pools.size}) +

    +
    ++ +
    + {poolTable.toNodeSeq} +
    + } else { + Seq.empty[Node] } - if (shouldShowActiveStages) { - content ++=

    Active Stages ({activeStages.size})

    ++ - activeStagesTable.toNodeSeq + + val content = summary ++ poolsDescription ++ tables.flatten.flatten + + UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + } + + private def summaryAndTableForStatus( + allStages: Seq[StageData], + appSummary: AppSummary, + status: StageStatus, + request: HttpServletRequest): (Option[Elem], Option[NodeSeq]) = { + val stages = if (status == StageStatus.FAILED) { + allStages.filter(_.status == status).reverse + } else { + allStages.filter(_.status == status) } - if (shouldShowPendingStages) { - content ++=

    Pending Stages ({pendingStages.size})

    ++ - pendingStagesTable.toNodeSeq + + if (stages.isEmpty) { + (None, None) + } else { + val killEnabled = status == StageStatus.ACTIVE && parent.killEnabled + val isFailedStage = status == StageStatus.FAILED + + val stagesTable = + new StageTableBase(parent.store, request, stages, statusName(status), stageTag(status), + parent.basePath, subPath, parent.isFairScheduler, killEnabled, isFailedStage) + val stagesSize = stages.size + (Some(summary(appSummary, status, stagesSize)), + Some(table(appSummary, status, stagesTable, stagesSize))) } - if (shouldShowCompletedStages) { - content ++=

    Completed Stages ({completedStageNumStr})

    ++ - completedStagesTable.toNodeSeq + } + + private def statusName(status: StageStatus): String = status match { + case StageStatus.ACTIVE => "active" + case StageStatus.COMPLETE => "completed" + case StageStatus.FAILED => "failed" + case StageStatus.PENDING => "pending" + case StageStatus.SKIPPED => "skipped" + } + + private def stageTag(status: StageStatus): String = s"${statusName(status)}Stage" + + private def headerDescription(status: StageStatus): String = statusName(status).capitalize + + private def summaryContent(appSummary: AppSummary, status: StageStatus, size: Int): String = { + if (status == StageStatus.COMPLETE && appSummary.numCompletedStages != size) { + s"${appSummary.numCompletedStages}, only showing $size" + } else { + s"$size" } - if (shouldShowFailedStages) { - content ++=

    Failed Stages ({numFailedStages})

    ++ - failedStagesTable.toNodeSeq + } + + private def summary(appSummary: AppSummary, status: StageStatus, size: Int): Elem = { + val summary = +
  • + + {headerDescription(status)} Stages: + + {summaryContent(appSummary, status, size)} +
  • + + if (status == StageStatus.COMPLETE) { + summary % Attribute(None, "id", Text("completed-summary"), Null) + } else { + summary } - UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + } + + private def table( + appSummary: AppSummary, + status: StageStatus, + stagesTable: StageTableBase, + size: Int): NodeSeq = { + val classSuffix = s"${statusName(status).capitalize}Stages" + +

    + + {headerDescription(status)} Stages ({summaryContent(appSummary, status, size)}) +

    +
    ++ +
    + {stagesTable.toNodeSeq} +
    } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 41d42b52430a..0ff64f053f37 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -87,7 +87,9 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { } private def createExecutorTable(stage: StageData) : Seq[Node] = { - stage.executorSummary.getOrElse(Map.empty).toSeq.sortBy(_._1).map { case (k, v) => + val executorSummary = store.executorSummary(stage.stageId, stage.attemptId) + + executorSummary.toSeq.sortBy(_._1).map { case (k, v) => val executor = store.asOption(store.executorSummary(k)) @@ -134,7 +136,15 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { {Utils.bytesToString(v.diskBytesSpilled)} }} - {executor.map(_.isBlacklisted).getOrElse(false)} + { + if (executor.map(_.isBlacklisted).getOrElse(false)) { + for application + } else if (v.isBlacklistedForStage) { + for stage + } else { + false + } + } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 740f12e7d13d..46f2a76cc651 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -201,7 +201,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP val stages = jobData.stageIds.map { stageId => // This could be empty if the listener hasn't received information about the // stage or if the stage information has been garbage collected - store.stageData(stageId).lastOption.getOrElse { + store.asOption(store.lastStageAttempt(stageId)).getOrElse { new v1.StageData( v1.StageStatus.PENDING, stageId, @@ -336,28 +336,82 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, store.executorList(false), appStartTime) - content ++= UIUtils.showDagVizForJob( - jobId, store.operationGraphForJob(jobId)) + val operationGraphContent = store.asOption(store.operationGraphForJob(jobId)) match { + case Some(operationGraph) => UIUtils.showDagVizForJob(jobId, operationGraph) + case None => +
    +

    No DAG visualization information to display for job {jobId}

    +
    + } + content ++= operationGraphContent if (shouldShowActiveStages) { - content ++=

    Active Stages ({activeStages.size})

    ++ - activeStagesTable.toNodeSeq + content ++= + +

    + + Active Stages ({activeStages.size}) +

    +
    ++ +
    + {activeStagesTable.toNodeSeq} +
    } if (shouldShowPendingStages) { - content ++=

    Pending Stages ({pendingOrSkippedStages.size})

    ++ - pendingOrSkippedStagesTable.toNodeSeq + content ++= + +

    + + Pending Stages ({pendingOrSkippedStages.size}) +

    +
    ++ +
    + {pendingOrSkippedStagesTable.toNodeSeq} +
    } if (shouldShowCompletedStages) { - content ++=

    Completed Stages ({completedStages.size})

    ++ - completedStagesTable.toNodeSeq + content ++= + +

    + + Completed Stages ({completedStages.size}) +

    +
    ++ +
    + {completedStagesTable.toNodeSeq} +
    } if (shouldShowSkippedStages) { - content ++=

    Skipped Stages ({pendingOrSkippedStages.size})

    ++ - pendingOrSkippedStagesTable.toNodeSeq + content ++= + +

    + + Skipped Stages ({pendingOrSkippedStages.size}) +

    +
    ++ +
    + {pendingOrSkippedStagesTable.toNodeSeq} +
    } if (shouldShowFailedStages) { - content ++=

    Failed Stages ({failedStages.size})

    ++ - failedStagesTable.toNodeSeq + content ++= + +

    + + Failed Stages ({failedStages.size}) +

    +
    ++ +
    + {failedStagesTable.toNodeSeq} +
    } UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 99eab1b2a27d..ff1b75e5c506 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -34,10 +34,10 @@ private[ui] class JobsTab(parent: SparkUI, store: AppStatusStore) val killEnabled = parent.killEnabled def isFairScheduler: Boolean = { - store.environmentInfo().sparkProperties.toMap - .get("spark.scheduler.mode") - .map { mode => mode == SchedulingMode.FAIR } - .getOrElse(false) + store + .environmentInfo() + .sparkProperties + .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString)) } def getSparkUser: String = parent.getSparkUser diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 98fbd7aceaa1..a3e1f13782e3 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -51,7 +51,18 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { val poolTable = new PoolTable(Map(pool -> uiPool), parent) var content =

    Summary

    ++ poolTable.toNodeSeq if (activeStages.nonEmpty) { - content ++=

    Active Stages ({activeStages.size})

    ++ activeStagesTable.toNodeSeq + content ++= + +

    + + Active Stages ({activeStages.size}) +

    +
    ++ +
    + {activeStagesTable.toNodeSeq} +
    } UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 11a6a3434497..ac83de10f923 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -19,25 +19,23 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder import java.util.Date +import java.util.concurrent.TimeUnit import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} -import scala.xml.{Elem, Node, Unparsed} +import scala.xml.{Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.SparkConf -import org.apache.spark.internal.config._ import org.apache.spark.scheduler.TaskLocality -import org.apache.spark.status.AppStatusStore +import org.apache.spark.status._ import org.apache.spark.status.api.v1._ import org.apache.spark.ui._ -import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.util.Utils /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends WebUIPage("stage") { import ApiHelper._ - import StagePage._ private val TIMELINE_LEGEND = {
    @@ -67,17 +65,17 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We // if we find that it's okay. private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) - private def getLocalitySummaryString(stageData: StageData, taskList: Seq[TaskData]): String = { - val localities = taskList.map(_.taskLocality) - val localityCounts = localities.groupBy(identity).mapValues(_.size) + private def getLocalitySummaryString(localitySummary: Map[String, Long]): String = { val names = Map( TaskLocality.PROCESS_LOCAL.toString() -> "Process local", TaskLocality.NODE_LOCAL.toString() -> "Node local", TaskLocality.RACK_LOCAL.toString() -> "Rack local", TaskLocality.ANY.toString() -> "Any") - val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) => - s"${names(locality)}: $count" - } + val localityNamesAndCounts = names.flatMap { case (key, name) => + localitySummary.get(key).map { count => + s"$name: $count" + } + }.toSeq localityNamesAndCounts.sorted.mkString("; ") } @@ -108,7 +106,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val stageHeader = s"Details for Stage $stageId (Attempt $stageAttemptId)" val stageData = parent.store - .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = true)) + .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = false)) .getOrElse { val content =
    @@ -117,8 +115,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We return UIUtils.headerSparkPage(stageHeader, content, parent) } - val tasks = stageData.tasks.getOrElse(Map.empty).values.toSeq - if (tasks.isEmpty) { + val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) + + val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + + stageData.numFailedTasks + stageData.numKilledTasks + if (totalTasks == 0) { val content =

    Summary Metrics

    No tasks have started yet @@ -127,18 +128,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We return UIUtils.headerSparkPage(stageHeader, content, parent) } + val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId) val numCompleted = stageData.numCompleteTasks - val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + - stageData.numFailedTasks + stageData.numKilledTasks - val totalTasksNumStr = if (totalTasks == tasks.size) { + val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$totalTasks, showing ${tasks.size}" + s"$totalTasks, showing ${storedTasks}" } - val externalAccumulables = stageData.accumulatorUpdates - val hasAccumulators = externalAccumulables.size > 0 - val summary =
      @@ -148,7 +145,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
    • Locality Level Summary: - {getLocalitySummaryString(stageData, tasks)} + {getLocalitySummaryString(localitySummary)}
    • {if (hasInput(stageData)) {
    • @@ -261,12 +258,16 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") def accumulableRow(acc: AccumulableInfo): Seq[Node] = { - {acc.name}{acc.value} + if (acc.name != null && acc.value != null) { + {acc.name}{acc.value} + } else { + Nil + } } val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, - externalAccumulables.toSeq) + stageData.accumulatorUpdates.toSeq) val page: Int = { // If the user has changed to a larger page size, then go to page 1 in order to avoid @@ -280,16 +281,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( - parent.conf, + stageData, UIUtils.prependBaseUri(parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", - tasks, - hasAccumulators, - hasInput(stageData), - hasOutput(stageData), - hasShuffleRead(stageData), - hasShuffleWrite(stageData), - hasBytesSpilled(stageData), currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, @@ -320,217 +314,155 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We | } |}); """.stripMargin - } + } } - val taskIdsInPage = if (taskTable == null) Set.empty[Long] - else taskTable.dataSource.slicedTaskIds - - // Excludes tasks which failed and have incomplete metrics - val validTasks = tasks.filter(t => t.status == "SUCCESS" && t.taskMetrics.isDefined) + val metricsSummary = store.taskSummary(stageData.stageId, stageData.attemptId, + Array(0, 0.25, 0.5, 0.75, 1.0)) - val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { - None - } else { - def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = { - Distribution(data).get.getQuantiles() - } - def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { - getDistributionQuantiles(times).map { millis => - {UIUtils.formatDuration(millis.toLong)} - } - } - def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { - getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) + val summaryTable = metricsSummary.map { metrics => + def timeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { + data.map { millis => + {UIUtils.formatDuration(millis.toLong)} } + } - val deserializationTimes = validTasks.map { task => - task.taskMetrics.get.executorDeserializeTime.toDouble - } - val deserializationQuantiles = - - - Task Deserialization Time - - +: getFormattedTimeQuantiles(deserializationTimes) - - val serviceTimes = validTasks.map(_.taskMetrics.get.executorRunTime.toDouble) - val serviceQuantiles = Duration +: getFormattedTimeQuantiles(serviceTimes) - - val gcTimes = validTasks.map(_.taskMetrics.get.jvmGcTime.toDouble) - val gcQuantiles = - - GC Time - - +: getFormattedTimeQuantiles(gcTimes) - - val serializationTimes = validTasks.map(_.taskMetrics.get.resultSerializationTime.toDouble) - val serializationQuantiles = - - - Result Serialization Time - - +: getFormattedTimeQuantiles(serializationTimes) - - val gettingResultTimes = validTasks.map(getGettingResultTime(_, currentTime).toDouble) - val gettingResultQuantiles = - - - Getting Result Time - - +: - getFormattedTimeQuantiles(gettingResultTimes) - - val peakExecutionMemory = validTasks.map(_.taskMetrics.get.peakExecutionMemory.toDouble) - val peakExecutionMemoryQuantiles = { - - - Peak Execution Memory - - +: 
getFormattedSizeQuantiles(peakExecutionMemory) + def sizeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { + data.map { size => + {Utils.bytesToString(size.toLong)} } + } - // The scheduler delay includes the network delay to send the task to the worker - // machine and to send back the result (but not the time to fetch the task result, - // if it needed to be fetched from the block manager on the worker). - val schedulerDelays = validTasks.map { task => - getSchedulerDelay(task, task.taskMetrics.get, currentTime).toDouble - } - val schedulerDelayTitle = Scheduler Delay - val schedulerDelayQuantiles = schedulerDelayTitle +: - getFormattedTimeQuantiles(schedulerDelays) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) - : Seq[Elem] = { - val recordDist = getDistributionQuantiles(records).iterator - getDistributionQuantiles(data).map(d => - {s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"} - ) + def sizeQuantilesWithRecords( + data: IndexedSeq[Double], + records: IndexedSeq[Double]) : Seq[Node] = { + data.zip(records).map { case (d, r) => + {s"${Utils.bytesToString(d.toLong)} / ${r.toLong}"} } + } - val inputSizes = validTasks.map(_.taskMetrics.get.inputMetrics.bytesRead.toDouble) - val inputRecords = validTasks.map(_.taskMetrics.get.inputMetrics.recordsRead.toDouble) - val inputQuantiles = Input Size / Records +: - getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) + def titleCell(title: String, tooltip: String): Seq[Node] = { + + + {title} + + + } - val outputSizes = validTasks.map(_.taskMetrics.get.outputMetrics.bytesWritten.toDouble) - val outputRecords = validTasks.map(_.taskMetrics.get.outputMetrics.recordsWritten.toDouble) - val outputQuantiles = Output Size / Records +: - getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) + def simpleTitleCell(title: String): Seq[Node] = {title} - val shuffleReadBlockedTimes = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime.toDouble - } - val shuffleReadBlockedQuantiles = - - - Shuffle Read Blocked Time - - +: - getFormattedTimeQuantiles(shuffleReadBlockedTimes) - - val shuffleReadTotalSizes = validTasks.map { task => - totalBytesRead(task.taskMetrics.get.shuffleReadMetrics).toDouble - } - val shuffleReadTotalRecords = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.recordsRead.toDouble - } - val shuffleReadTotalQuantiles = - - - Shuffle Read Size / Records - - +: - getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) - - val shuffleReadRemoteSizes = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead.toDouble - } - val shuffleReadRemoteQuantiles = - - - Shuffle Remote Reads - - +: - getFormattedSizeQuantiles(shuffleReadRemoteSizes) - - val shuffleWriteSizes = validTasks.map { task => - task.taskMetrics.get.shuffleWriteMetrics.bytesWritten.toDouble - } + val deserializationQuantiles = titleCell("Task Deserialization Time", + ToolTips.TASK_DESERIALIZATION_TIME) ++ timeQuantiles(metrics.executorDeserializeTime) - val shuffleWriteRecords = validTasks.map { task => - task.taskMetrics.get.shuffleWriteMetrics.recordsWritten.toDouble - } + val serviceQuantiles = simpleTitleCell("Duration") ++ timeQuantiles(metrics.executorRunTime) - val shuffleWriteQuantiles = Shuffle Write Size / Records +: - getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) + val gcQuantiles = titleCell("GC Time", ToolTips.GC_TIME) ++ timeQuantiles(metrics.jvmGcTime) - 
val memoryBytesSpilledSizes = validTasks.map(_.taskMetrics.get.memoryBytesSpilled.toDouble) - val memoryBytesSpilledQuantiles = Shuffle spill (memory) +: - getFormattedSizeQuantiles(memoryBytesSpilledSizes) + val serializationQuantiles = titleCell("Result Serialization Time", + ToolTips.RESULT_SERIALIZATION_TIME) ++ timeQuantiles(metrics.resultSerializationTime) - val diskBytesSpilledSizes = validTasks.map(_.taskMetrics.get.diskBytesSpilled.toDouble) - val diskBytesSpilledQuantiles = Shuffle spill (disk) +: - getFormattedSizeQuantiles(diskBytesSpilledSizes) + val gettingResultQuantiles = titleCell("Getting Result Time", ToolTips.GETTING_RESULT_TIME) ++ + timeQuantiles(metrics.gettingResultTime) - val listings: Seq[Seq[Node]] = Seq( - {serviceQuantiles}, - {schedulerDelayQuantiles}, - - {deserializationQuantiles} - - {gcQuantiles}, - - {serializationQuantiles} - , - {gettingResultQuantiles}, - - {peakExecutionMemoryQuantiles} - , - if (hasInput(stageData)) {inputQuantiles} else Nil, - if (hasOutput(stageData)) {outputQuantiles} else Nil, - if (hasShuffleRead(stageData)) { - - {shuffleReadBlockedQuantiles} - - {shuffleReadTotalQuantiles} - - {shuffleReadRemoteQuantiles} - - } else { - Nil - }, - if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", - "Median", "75th percentile", "Max") - // The summary table does not use CSS to stripe rows, which doesn't work with hidden - // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). - Some(UIUtils.listingTable( - quantileHeaders, - identity[Seq[Node]], - listings, - fixedWidth = true, - id = Some("task-summary-table"), - stripeRowsWithCss = false)) + val peakExecutionMemoryQuantiles = titleCell("Peak Execution Memory", + ToolTips.PEAK_EXECUTION_MEMORY) ++ sizeQuantiles(metrics.peakExecutionMemory) + + // The scheduler delay includes the network delay to send the task to the worker + // machine and to send back the result (but not the time to fetch the task result, + // if it needed to be fetched from the block manager on the worker). 
+ val schedulerDelayQuantiles = titleCell("Scheduler Delay", ToolTips.SCHEDULER_DELAY) ++ + timeQuantiles(metrics.schedulerDelay) + + def inputQuantiles: Seq[Node] = { + simpleTitleCell("Input Size / Records") ++ + sizeQuantilesWithRecords(metrics.inputMetrics.bytesRead, metrics.inputMetrics.recordsRead) + } + + def outputQuantiles: Seq[Node] = { + simpleTitleCell("Output Size / Records") ++ + sizeQuantilesWithRecords(metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten) } + def shuffleReadBlockedQuantiles: Seq[Node] = { + titleCell("Shuffle Read Blocked Time", ToolTips.SHUFFLE_READ_BLOCKED_TIME) ++ + timeQuantiles(metrics.shuffleReadMetrics.fetchWaitTime) + } + + def shuffleReadTotalQuantiles: Seq[Node] = { + titleCell("Shuffle Read Size / Records", ToolTips.SHUFFLE_READ) ++ + sizeQuantilesWithRecords(metrics.shuffleReadMetrics.readBytes, + metrics.shuffleReadMetrics.readRecords) + } + + def shuffleReadRemoteQuantiles: Seq[Node] = { + titleCell("Shuffle Remote Reads", ToolTips.SHUFFLE_READ_REMOTE_SIZE) ++ + sizeQuantiles(metrics.shuffleReadMetrics.remoteBytesRead) + } + + def shuffleWriteQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle Write Size / Records") ++ + sizeQuantilesWithRecords(metrics.shuffleWriteMetrics.writeBytes, + metrics.shuffleWriteMetrics.writeRecords) + } + + def memoryBytesSpilledQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle spill (memory)") ++ sizeQuantiles(metrics.memoryBytesSpilled) + } + + def diskBytesSpilledQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle spill (disk)") ++ sizeQuantiles(metrics.diskBytesSpilled) + } + + val listings: Seq[Seq[Node]] = Seq( + {serviceQuantiles}, + {schedulerDelayQuantiles}, + + {deserializationQuantiles} + + {gcQuantiles}, + + {serializationQuantiles} + , + {gettingResultQuantiles}, + + {peakExecutionMemoryQuantiles} + , + if (hasInput(stageData)) {inputQuantiles} else Nil, + if (hasOutput(stageData)) {outputQuantiles} else Nil, + if (hasShuffleRead(stageData)) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadTotalQuantiles} + + {shuffleReadRemoteQuantiles} + + } else { + Nil + }, + if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) + + val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", + "Max") + // The summary table does not use CSS to stripe rows, which doesn't work with hidden + // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). + UIUtils.listingTable( + quantileHeaders, + identity[Seq[Node]], + listings, + fixedWidth = true, + id = Some("task-summary-table"), + stripeRowsWithCss = false) + } + val executorTable = new ExecutorTable(stageData, parent.store) val maybeAccumulableTable: Seq[Node] = - if (hasAccumulators) {

      <h4>Accumulators</h4>

      ++ accumulableTable } else Seq() + if (hasAccumulators(stageData)) {

      <h4>Accumulators</h4>

      ++ accumulableTable } else Seq() val aggMetrics = taskIdsInPage.contains(t.taskId) }, + Option(taskTable).map(_.dataSource.tasks).getOrElse(Nil), currentTime) ++

      <h4>Summary Metrics for {numCompleted} Completed Tasks</h4>

      ++
      {summaryTable.getOrElse("No tasks have reported metrics yet.")}
      ++ aggMetrics ++ maybeAccumulableTable ++ -

      <h4 id="tasks-section">Tasks ({totalTasksNumStr})</h4>

      ++ - taskTableHTML ++ jsForScrollingDownToTaskTable + +

      + + Tasks ({totalTasksNumStr}) +

      +
      ++ +
      + {taskTableHTML ++ jsForScrollingDownToTaskTable} +
      UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } @@ -593,10 +533,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) - val gettingResultTime = getGettingResultTime(taskInfo, currentTime) + val gettingResultTime = AppStatusUtils.gettingResultTime(taskInfo) val gettingResultTimeProportion = toProportion(gettingResultTime) - val schedulerDelay = - metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L) + val schedulerDelay = AppStatusUtils.schedulerDelay(taskInfo) val schedulerDelayProportion = toProportion(schedulerDelay) val executorOverhead = serializationTime + deserializationTime @@ -708,7 +647,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We { if (MAX_TIMELINE_TASKS < tasks.size) { - This stage has more than the maximum number of tasks that can be shown in the + This page has more than the maximum number of tasks that can be shown in the visualization! Only the most recent {MAX_TIMELINE_TASKS} tasks (of {tasks.size} total) are shown. @@ -733,402 +672,49 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } -private[ui] object StagePage { - private[ui] def getGettingResultTime(info: TaskData, currentTime: Long): Long = { - info.resultFetchStart match { - case Some(start) => - info.duration match { - case Some(duration) => - info.launchTime.getTime() + duration - start.getTime() - - case _ => - currentTime - start.getTime() - } - - case _ => - 0L - } - } - - private[ui] def getSchedulerDelay( - info: TaskData, - metrics: TaskMetrics, - currentTime: Long): Long = { - info.duration match { - case Some(duration) => - val executorOverhead = metrics.executorDeserializeTime + metrics.resultSerializationTime - math.max( - 0, - duration - metrics.executorRunTime - executorOverhead - - getGettingResultTime(info, currentTime)) - - case _ => - // The task is still running and the metrics like executorRunTime are not available. - 0L - } - } - -} - -private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String) - -private[ui] case class TaskTableRowOutputData(outputSortable: Long, outputReadable: String) - -private[ui] case class TaskTableRowShuffleReadData( - shuffleReadBlockedTimeSortable: Long, - shuffleReadBlockedTimeReadable: String, - shuffleReadSortable: Long, - shuffleReadReadable: String, - shuffleReadRemoteSortable: Long, - shuffleReadRemoteReadable: String) - -private[ui] case class TaskTableRowShuffleWriteData( - writeTimeSortable: Long, - writeTimeReadable: String, - shuffleWriteSortable: Long, - shuffleWriteReadable: String) - -private[ui] case class TaskTableRowBytesSpilledData( - memoryBytesSpilledSortable: Long, - memoryBytesSpilledReadable: String, - diskBytesSpilledSortable: Long, - diskBytesSpilledReadable: String) - -/** - * Contains all data that needs for sorting and generating HTML. Using this one rather than - * TaskData to avoid creating duplicate contents during sorting the data. 
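The removed comment above describes converting every TaskData into a TaskTableRowData and sorting the whole list in memory; the replacement below instead asks the status store for one already-sorted page at a time. A rough sketch of that access pattern, assuming the taskList overload and the ApiHelper.indexName helper introduced in this patch:

    import org.apache.spark.status.AppStatusStore
    import org.apache.spark.status.api.v1.{StageData, TaskData}

    // Fetch rows [from, to) of the task table, pre-sorted by the KV-store index
    // that backs the requested column. indexName maps a column header to a
    // TaskIndexNames entry (or None for the default task-ID ordering) and rejects
    // unknown columns up front.
    def pageOfTasks(
        store: AppStatusStore,
        stage: StageData,
        from: Int,
        to: Int,
        sortColumn: String,
        desc: Boolean): Seq[TaskData] = {
      store.taskList(stage.stageId, stage.attemptId, from, to - from,
        ApiHelper.indexName(sortColumn), !desc)
    }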
- */ -private[ui] class TaskTableRowData( - val index: Int, - val taskId: Long, - val attempt: Int, - val speculative: Boolean, - val status: String, - val taskLocality: String, - val executorId: String, - val host: String, - val launchTime: Long, - val duration: Long, - val formatDuration: String, - val schedulerDelay: Long, - val taskDeserializationTime: Long, - val gcTime: Long, - val serializationTime: Long, - val gettingResultTime: Long, - val peakExecutionMemoryUsed: Long, - val accumulators: Option[String], // HTML - val input: Option[TaskTableRowInputData], - val output: Option[TaskTableRowOutputData], - val shuffleRead: Option[TaskTableRowShuffleReadData], - val shuffleWrite: Option[TaskTableRowShuffleWriteData], - val bytesSpilled: Option[TaskTableRowBytesSpilledData], - val error: String, - val logs: Map[String, String]) - private[ui] class TaskDataSource( - tasks: Seq[TaskData], - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, + stage: StageData, currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, - store: AppStatusStore) extends PagedDataSource[TaskTableRowData](pageSize) { - import StagePage._ + store: AppStatusStore) extends PagedDataSource[TaskData](pageSize) { + import ApiHelper._ // Keep an internal cache of executor log maps so that long task lists render faster. private val executorIdToLogs = new HashMap[String, Map[String, String]]() - // Convert TaskData to TaskTableRowData which contains the final contents to show in the table - // so that we can avoid creating duplicate contents during sorting the data - private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) - - private var _slicedTaskIds: Set[Long] = _ - - override def dataSize: Int = data.size - - override def sliceData(from: Int, to: Int): Seq[TaskTableRowData] = { - val r = data.slice(from, to) - _slicedTaskIds = r.map(_.taskId).toSet - r - } + private var _tasksToShow: Seq[TaskData] = null - def slicedTaskIds: Set[Long] = _slicedTaskIds + override def dataSize: Int = stage.numTasks - private def taskRow(info: TaskData): TaskTableRowData = { - val metrics = info.taskMetrics - val duration = info.duration.getOrElse(1L) - val formatDuration = info.duration.map(d => UIUtils.formatDuration(d)).getOrElse("") - val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) - val gcTime = metrics.map(_.jvmGcTime).getOrElse(0L) - val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) - val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = getGettingResultTime(info, currentTime) - - val externalAccumulableReadable = info.accumulatorUpdates.map { acc => - StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}") + override def sliceData(from: Int, to: Int): Seq[TaskData] = { + if (_tasksToShow == null) { + _tasksToShow = store.taskList(stage.stageId, stage.attemptId, from, to - from, + indexName(sortColumn), !desc) } - val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L) - - val maybeInput = metrics.map(_.inputMetrics) - val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) - val inputReadable = maybeInput - .map(m => s"${Utils.bytesToString(m.bytesRead)}") - .getOrElse("") - val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") - - val maybeOutput = metrics.map(_.outputMetrics) - val outputSortable = 
maybeOutput.map(_.bytesWritten).getOrElse(0L) - val outputReadable = maybeOutput - .map(m => s"${Utils.bytesToString(m.bytesWritten)}") - .getOrElse("") - val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") - - val maybeShuffleRead = metrics.map(_.shuffleReadMetrics) - val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L) - val shuffleReadBlockedTimeReadable = - maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - - val totalShuffleBytes = maybeShuffleRead.map(ApiHelper.totalBytesRead) - val shuffleReadSortable = totalShuffleBytes.getOrElse(0L) - val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") - val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") - - val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) - val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L) - val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") - - val maybeShuffleWrite = metrics.map(_.shuffleWriteMetrics) - val shuffleWriteSortable = maybeShuffleWrite.map(_.bytesWritten).getOrElse(0L) - val shuffleWriteReadable = maybeShuffleWrite - .map(m => s"${Utils.bytesToString(m.bytesWritten)}").getOrElse("") - val shuffleWriteRecords = maybeShuffleWrite - .map(_.recordsWritten.toString).getOrElse("") - - val maybeWriteTime = metrics.map(_.shuffleWriteMetrics.writeTime) - val writeTimeSortable = maybeWriteTime.getOrElse(0L) - val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => - if (ms == 0) "" else UIUtils.formatDuration(ms) - }.getOrElse("") - - val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) - val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.getOrElse(0L) - val memoryBytesSpilledReadable = - maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) - val diskBytesSpilledSortable = maybeDiskBytesSpilled.getOrElse(0L) - val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val input = - if (hasInput) { - Some(TaskTableRowInputData(inputSortable, s"$inputReadable / $inputRecords")) - } else { - None - } - - val output = - if (hasOutput) { - Some(TaskTableRowOutputData(outputSortable, s"$outputReadable / $outputRecords")) - } else { - None - } - - val shuffleRead = - if (hasShuffleRead) { - Some(TaskTableRowShuffleReadData( - shuffleReadBlockedTimeSortable, - shuffleReadBlockedTimeReadable, - shuffleReadSortable, - s"$shuffleReadReadable / $shuffleReadRecords", - shuffleReadRemoteSortable, - shuffleReadRemoteReadable - )) - } else { - None - } - - val shuffleWrite = - if (hasShuffleWrite) { - Some(TaskTableRowShuffleWriteData( - writeTimeSortable, - writeTimeReadable, - shuffleWriteSortable, - s"$shuffleWriteReadable / $shuffleWriteRecords" - )) - } else { - None - } - - val bytesSpilled = - if (hasBytesSpilled) { - Some(TaskTableRowBytesSpilledData( - memoryBytesSpilledSortable, - memoryBytesSpilledReadable, - diskBytesSpilledSortable, - diskBytesSpilledReadable - )) - } else { - None - } - - new TaskTableRowData( - info.index, - info.taskId, - info.attempt, - info.speculative, - info.status, - info.taskLocality.toString, - info.executorId, - info.host, - info.launchTime.getTime(), - duration, - formatDuration, - schedulerDelay, - taskDeserializationTime, - gcTime, - serializationTime, - gettingResultTime, - peakExecutionMemoryUsed, - if (hasAccumulators) 
Some(externalAccumulableReadable.mkString("
      ")) else None, - input, - output, - shuffleRead, - shuffleWrite, - bytesSpilled, - info.errorMessage.getOrElse(""), - executorLogs(info.executorId)) + _tasksToShow } - private def executorLogs(id: String): Map[String, String] = { + def tasks: Seq[TaskData] = _tasksToShow + + def executorLogs(id: String): Map[String, String] = { executorIdToLogs.getOrElseUpdate(id, store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty)) } - /** - * Return Ordering according to sortColumn and desc - */ - private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = { - val ordering: Ordering[TaskTableRowData] = sortColumn match { - case "Index" => Ordering.by(_.index) - case "ID" => Ordering.by(_.taskId) - case "Attempt" => Ordering.by(_.attempt) - case "Status" => Ordering.by(_.status) - case "Locality Level" => Ordering.by(_.taskLocality) - case "Executor ID" => Ordering.by(_.executorId) - case "Host" => Ordering.by(_.host) - case "Launch Time" => Ordering.by(_.launchTime) - case "Duration" => Ordering.by(_.duration) - case "Scheduler Delay" => Ordering.by(_.schedulerDelay) - case "Task Deserialization Time" => Ordering.by(_.taskDeserializationTime) - case "GC Time" => Ordering.by(_.gcTime) - case "Result Serialization Time" => Ordering.by(_.serializationTime) - case "Getting Result Time" => Ordering.by(_.gettingResultTime) - case "Peak Execution Memory" => Ordering.by(_.peakExecutionMemoryUsed) - case "Accumulators" => - if (hasAccumulators) { - Ordering.by(_.accumulators.get) - } else { - throw new IllegalArgumentException( - "Cannot sort by Accumulators because of no accumulators") - } - case "Input Size / Records" => - if (hasInput) { - Ordering.by(_.input.get.inputSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Input Size / Records because of no inputs") - } - case "Output Size / Records" => - if (hasOutput) { - Ordering.by(_.output.get.outputSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Output Size / Records because of no outputs") - } - // ShuffleRead - case "Shuffle Read Blocked Time" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadBlockedTimeSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads") - } - case "Shuffle Read Size / Records" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Read Size / Records because of no shuffle reads") - } - case "Shuffle Remote Reads" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadRemoteSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Remote Reads because of no shuffle reads") - } - // ShuffleWrite - case "Write Time" => - if (hasShuffleWrite) { - Ordering.by(_.shuffleWrite.get.writeTimeSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Write Time because of no shuffle writes") - } - case "Shuffle Write Size / Records" => - if (hasShuffleWrite) { - Ordering.by(_.shuffleWrite.get.shuffleWriteSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Write Size / Records because of no shuffle writes") - } - // BytesSpilled - case "Shuffle Spill (Memory)" => - if (hasBytesSpilled) { - Ordering.by(_.bytesSpilled.get.memoryBytesSpilledSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Spill (Memory) because of 
no spills") - } - case "Shuffle Spill (Disk)" => - if (hasBytesSpilled) { - Ordering.by(_.bytesSpilled.get.diskBytesSpilledSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Spill (Disk) because of no spills") - } - case "Errors" => Ordering.by(_.error) - case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") - } - if (desc) { - ordering.reverse - } else { - ordering - } - } - } private[ui] class TaskPagedTable( - conf: SparkConf, + stage: StageData, basePath: String, - data: Seq[TaskData], - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, - store: AppStatusStore) extends PagedTable[TaskTableRowData] { + store: AppStatusStore) extends PagedTable[TaskData] { + + import ApiHelper._ override def tableId: String = "task-table" @@ -1142,13 +728,7 @@ private[ui] class TaskPagedTable( override def pageNumberFormField: String = "task.page" override val dataSource: TaskDataSource = new TaskDataSource( - data, - hasAccumulators, - hasInput, - hasOutput, - hasShuffleRead, - hasShuffleWrite, - hasBytesSpilled, + stage, currentTime, pageSize, sortColumn, @@ -1170,37 +750,39 @@ private[ui] class TaskPagedTable( } def headers: Seq[Node] = { + import ApiHelper._ + val taskHeadersAndCssClasses: Seq[(String, String)] = Seq( - ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID", ""), ("Host", ""), ("Launch Time", ""), ("Duration", ""), - ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), - ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), - ("GC Time", ""), - ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), - ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME), - ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ - {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ - {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ - {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ - {if (hasShuffleRead) { - Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), - ("Shuffle Read Size / Records", ""), - ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) + (HEADER_TASK_INDEX, ""), (HEADER_ID, ""), (HEADER_ATTEMPT, ""), (HEADER_STATUS, ""), + (HEADER_LOCALITY, ""), (HEADER_EXECUTOR, ""), (HEADER_HOST, ""), (HEADER_LAUNCH_TIME, ""), + (HEADER_DURATION, ""), (HEADER_SCHEDULER_DELAY, TaskDetailsClassNames.SCHEDULER_DELAY), + (HEADER_DESER_TIME, TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), + (HEADER_GC_TIME, ""), + (HEADER_SER_TIME, TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + (HEADER_GETTING_RESULT_TIME, TaskDetailsClassNames.GETTING_RESULT_TIME), + (HEADER_PEAK_MEM, TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ + {if (hasAccumulators(stage)) Seq((HEADER_ACCUMULATORS, "")) else Nil} ++ + {if (hasInput(stage)) Seq((HEADER_INPUT_SIZE, "")) else Nil} ++ + {if (hasOutput(stage)) Seq((HEADER_OUTPUT_SIZE, "")) else Nil} ++ + {if (hasShuffleRead(stage)) { + Seq((HEADER_SHUFFLE_READ_TIME, TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + (HEADER_SHUFFLE_TOTAL_READS, ""), + (HEADER_SHUFFLE_REMOTE_READS, TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) } else { Nil }} ++ - {if (hasShuffleWrite) { - Seq(("Write Time", ""), ("Shuffle Write 
Size / Records", "")) + {if (hasShuffleWrite(stage)) { + Seq((HEADER_SHUFFLE_WRITE_TIME, ""), (HEADER_SHUFFLE_WRITE_SIZE, "")) } else { Nil }} ++ - {if (hasBytesSpilled) { - Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + {if (hasBytesSpilled(stage)) { + Seq((HEADER_MEM_SPILL, ""), (HEADER_DISK_SPILL, "")) } else { Nil }} ++ - Seq(("Errors", "")) + Seq((HEADER_ERROR, "")) if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) { throw new IllegalArgumentException(s"Unknown column: $sortColumn") @@ -1237,7 +819,17 @@ private[ui] class TaskPagedTable( {headerRow} } - def row(task: TaskTableRowData): Seq[Node] = { + def row(task: TaskData): Seq[Node] = { + def formatDuration(value: Option[Long], hideZero: Boolean = false): String = { + value.map { v => + if (v > 0 || !hideZero) UIUtils.formatDuration(v) else "" + }.getOrElse("") + } + + def formatBytes(value: Option[Long]): String = { + Utils.bytesToString(value.getOrElse(0L)) + } + {task.index} {task.taskId} @@ -1249,62 +841,102 @@ private[ui] class TaskPagedTable(
      {task.host}
      { - task.logs.map { + dataSource.executorLogs(task.executorId).map { case (logName, logUrl) => } }
      - {UIUtils.formatDate(new Date(task.launchTime))} - {task.formatDuration} + {UIUtils.formatDate(task.launchTime)} + {formatDuration(task.duration)} - {UIUtils.formatDuration(task.schedulerDelay)} + {UIUtils.formatDuration(AppStatusUtils.schedulerDelay(task))} - {UIUtils.formatDuration(task.taskDeserializationTime)} + {formatDuration(task.taskMetrics.map(_.executorDeserializeTime))} - {if (task.gcTime > 0) UIUtils.formatDuration(task.gcTime) else ""} + {formatDuration(task.taskMetrics.map(_.jvmGcTime), hideZero = true)} - {UIUtils.formatDuration(task.serializationTime)} + {formatDuration(task.taskMetrics.map(_.resultSerializationTime))} - {UIUtils.formatDuration(task.gettingResultTime)} + {UIUtils.formatDuration(AppStatusUtils.gettingResultTime(task))} - {Utils.bytesToString(task.peakExecutionMemoryUsed)} + {formatBytes(task.taskMetrics.map(_.peakExecutionMemory))} - {if (task.accumulators.nonEmpty) { - {Unparsed(task.accumulators.get)} + {if (hasAccumulators(stage)) { + {accumulatorsInfo(task)} }} - {if (task.input.nonEmpty) { - {task.input.get.inputReadable} + {if (hasInput(stage)) { + metricInfo(task) { m => + val bytesRead = Utils.bytesToString(m.inputMetrics.bytesRead) + val records = m.inputMetrics.recordsRead + {bytesRead} / {records} + } }} - {if (task.output.nonEmpty) { - {task.output.get.outputReadable} + {if (hasOutput(stage)) { + metricInfo(task) { m => + val bytesWritten = Utils.bytesToString(m.outputMetrics.bytesWritten) + val records = m.outputMetrics.recordsWritten + {bytesWritten} / {records} + } }} - {if (task.shuffleRead.nonEmpty) { + {if (hasShuffleRead(stage)) { - {task.shuffleRead.get.shuffleReadBlockedTimeReadable} + {formatDuration(task.taskMetrics.map(_.shuffleReadMetrics.fetchWaitTime))} - {task.shuffleRead.get.shuffleReadReadable} + { + metricInfo(task) { m => + val bytesRead = Utils.bytesToString(totalBytesRead(m.shuffleReadMetrics)) + val records = m.shuffleReadMetrics.recordsRead + Unparsed(s"$bytesRead / $records") + } + } - {task.shuffleRead.get.shuffleReadRemoteReadable} + {formatBytes(task.taskMetrics.map(_.shuffleReadMetrics.remoteBytesRead))} }} - {if (task.shuffleWrite.nonEmpty) { - {task.shuffleWrite.get.writeTimeReadable} - {task.shuffleWrite.get.shuffleWriteReadable} + {if (hasShuffleWrite(stage)) { + { + formatDuration( + task.taskMetrics.map { m => + TimeUnit.NANOSECONDS.toMillis(m.shuffleWriteMetrics.writeTime) + }, + hideZero = true) + } + { + metricInfo(task) { m => + val bytesWritten = Utils.bytesToString(m.shuffleWriteMetrics.bytesWritten) + val records = m.shuffleWriteMetrics.recordsWritten + Unparsed(s"$bytesWritten / $records") + } + } }} - {if (task.bytesSpilled.nonEmpty) { - {task.bytesSpilled.get.memoryBytesSpilledReadable} - {task.bytesSpilled.get.diskBytesSpilledReadable} + {if (hasBytesSpilled(stage)) { + {formatBytes(task.taskMetrics.map(_.memoryBytesSpilled))} + {formatBytes(task.taskMetrics.map(_.diskBytesSpilled))} }} - {errorMessageCell(task.error)} + {errorMessageCell(task.errorMessage.getOrElse(""))} } + private def accumulatorsInfo(task: TaskData): Seq[Node] = { + task.accumulatorUpdates.flatMap { acc => + if (acc.name != null && acc.update.isDefined) { + Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")) ++
      + } else { + Nil + } + } + } + + private def metricInfo(task: TaskData)(fn: TaskMetrics => Seq[Node]): Seq[Node] = { + task.taskMetrics.map(fn).getOrElse(Nil) + } + private def errorMessageCell(error: String): Seq[Node] = { val isMultiline = error.indexOf('\n') >= 0 // Display the first line by default @@ -1331,7 +963,66 @@ private[ui] class TaskPagedTable( } } -private object ApiHelper { +private[ui] object ApiHelper { + + val HEADER_ID = "ID" + val HEADER_TASK_INDEX = "Index" + val HEADER_ATTEMPT = "Attempt" + val HEADER_STATUS = "Status" + val HEADER_LOCALITY = "Locality Level" + val HEADER_EXECUTOR = "Executor ID" + val HEADER_HOST = "Host" + val HEADER_LAUNCH_TIME = "Launch Time" + val HEADER_DURATION = "Duration" + val HEADER_SCHEDULER_DELAY = "Scheduler Delay" + val HEADER_DESER_TIME = "Task Deserialization Time" + val HEADER_GC_TIME = "GC Time" + val HEADER_SER_TIME = "Result Serialization Time" + val HEADER_GETTING_RESULT_TIME = "Getting Result Time" + val HEADER_PEAK_MEM = "Peak Execution Memory" + val HEADER_ACCUMULATORS = "Accumulators" + val HEADER_INPUT_SIZE = "Input Size / Records" + val HEADER_OUTPUT_SIZE = "Output Size / Records" + val HEADER_SHUFFLE_READ_TIME = "Shuffle Read Blocked Time" + val HEADER_SHUFFLE_TOTAL_READS = "Shuffle Read Size / Records" + val HEADER_SHUFFLE_REMOTE_READS = "Shuffle Remote Reads" + val HEADER_SHUFFLE_WRITE_TIME = "Write Time" + val HEADER_SHUFFLE_WRITE_SIZE = "Shuffle Write Size / Records" + val HEADER_MEM_SPILL = "Shuffle Spill (Memory)" + val HEADER_DISK_SPILL = "Shuffle Spill (Disk)" + val HEADER_ERROR = "Errors" + + private[ui] val COLUMN_TO_INDEX = Map( + HEADER_ID -> null.asInstanceOf[String], + HEADER_TASK_INDEX -> TaskIndexNames.TASK_INDEX, + HEADER_ATTEMPT -> TaskIndexNames.ATTEMPT, + HEADER_STATUS -> TaskIndexNames.STATUS, + HEADER_LOCALITY -> TaskIndexNames.LOCALITY, + HEADER_EXECUTOR -> TaskIndexNames.EXECUTOR, + HEADER_HOST -> TaskIndexNames.HOST, + HEADER_LAUNCH_TIME -> TaskIndexNames.LAUNCH_TIME, + HEADER_DURATION -> TaskIndexNames.DURATION, + HEADER_SCHEDULER_DELAY -> TaskIndexNames.SCHEDULER_DELAY, + HEADER_DESER_TIME -> TaskIndexNames.DESER_TIME, + HEADER_GC_TIME -> TaskIndexNames.GC_TIME, + HEADER_SER_TIME -> TaskIndexNames.SER_TIME, + HEADER_GETTING_RESULT_TIME -> TaskIndexNames.GETTING_RESULT_TIME, + HEADER_PEAK_MEM -> TaskIndexNames.PEAK_MEM, + HEADER_ACCUMULATORS -> TaskIndexNames.ACCUMULATORS, + HEADER_INPUT_SIZE -> TaskIndexNames.INPUT_SIZE, + HEADER_OUTPUT_SIZE -> TaskIndexNames.OUTPUT_SIZE, + HEADER_SHUFFLE_READ_TIME -> TaskIndexNames.SHUFFLE_READ_TIME, + HEADER_SHUFFLE_TOTAL_READS -> TaskIndexNames.SHUFFLE_TOTAL_READS, + HEADER_SHUFFLE_REMOTE_READS -> TaskIndexNames.SHUFFLE_REMOTE_READS, + HEADER_SHUFFLE_WRITE_TIME -> TaskIndexNames.SHUFFLE_WRITE_TIME, + HEADER_SHUFFLE_WRITE_SIZE -> TaskIndexNames.SHUFFLE_WRITE_SIZE, + HEADER_MEM_SPILL -> TaskIndexNames.MEM_SPILL, + HEADER_DISK_SPILL -> TaskIndexNames.DISK_SPILL, + HEADER_ERROR -> TaskIndexNames.ERROR) + + def hasAccumulators(stageData: StageData): Boolean = { + stageData.accumulatorUpdates.exists { acc => acc.name != null && acc.value != null } + } def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0 @@ -1349,4 +1040,16 @@ private object ApiHelper { metrics.localBytesRead + metrics.remoteBytesRead } + def indexName(sortColumn: String): Option[String] = { + COLUMN_TO_INDEX.get(sortColumn) match { + case Some(v) => Option(v) + case _ => throw new IllegalArgumentException(s"Invalid sort column: $sortColumn") + } + } + + def 
lastStageNameAndDescription(store: AppStatusStore, job: JobData): (String, String) = { + val stage = store.asOption(store.stageAttempt(job.stageIds.max, 0)) + (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index be05a963f0e6..10b032084ce4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -37,10 +37,10 @@ private[ui] class StagesTab(val parent: SparkUI, val store: AppStatusStore) attachPage(new PoolPage(this)) def isFairScheduler: Boolean = { - store.environmentInfo().sparkProperties.toMap - .get("spark.scheduler.mode") - .map { mode => mode == SchedulingMode.FAIR } - .getOrElse(false) + store + .environmentInfo() + .sparkProperties + .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString)) } def handleKillRequest(request: HttpServletRequest): Unit = { diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 827a8637b9bd..948858224d72 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -116,7 +116,7 @@ private[spark] object RDDOperationGraph extends Logging { // Use a special prefix here to differentiate this cluster from other operation clusters val stageClusterId = STAGE_CLUSTER_PREFIX + stage.stageId val stageClusterName = s"Stage ${stage.stageId}" + - { if (stage.attemptId == 0) "" else s" (attempt ${stage.attemptId})" } + { if (stage.attemptNumber == 0) "" else s" (attempt ${stage.attemptNumber})" } val rootCluster = new RDDOperationCluster(stageClusterId, stageClusterName) var rootNodeCount = 0 diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 02cee7f8c5b3..2674b9291203 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -23,7 +23,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} import org.apache.spark.status.AppStatusStore -import org.apache.spark.status.api.v1.{RDDDataDistribution, RDDPartitionInfo} +import org.apache.spark.status.api.v1.{ExecutorSummary, RDDDataDistribution, RDDPartitionInfo} import org.apache.spark.ui._ import org.apache.spark.util.Utils @@ -76,7 +76,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, - blockSortDesc) + blockSortDesc, + store.executorList(true)) _blockTable.table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -182,7 +183,8 @@ private[ui] class BlockDataSource( rddPartitions: Seq[RDDPartitionInfo], pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedDataSource[BlockTableRowData](pageSize) { + desc: Boolean, + executorIdToAddress: Map[String, String]) extends PagedDataSource[BlockTableRowData](pageSize) { private val data = rddPartitions.map(blockRow).sorted(ordering(sortColumn, desc)) @@ -198,7 +200,10 @@ private[ui] class BlockDataSource( rddPartition.storageLevel, rddPartition.memoryUsed, rddPartition.diskUsed, - rddPartition.executors.mkString(" ")) + rddPartition.executors + .map { 
id => executorIdToAddress.get(id).getOrElse(id) } + .sorted + .mkString(" ")) } /** @@ -226,7 +231,8 @@ private[ui] class BlockPagedTable( rddPartitions: Seq[RDDPartitionInfo], pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedTable[BlockTableRowData] { + desc: Boolean, + executorSummaries: Seq[ExecutorSummary]) extends PagedTable[BlockTableRowData] { override def tableId: String = "rdd-storage-by-block-table" @@ -243,7 +249,8 @@ private[ui] class BlockPagedTable( rddPartitions, pageSize, sortColumn, - desc) + desc, + executorSummaries.map { ex => (ex.id, ex.hostPort) }.toMap) override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index b8aec9890247..68d946574a37 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -41,8 +41,16 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends Nil } else {
      -

      RDDs

      - {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} + +

      + + RDDs +

      +
      +
      + {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} +
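The markup added here follows the collapsible header/body pattern used elsewhere in the UI: a clickable header span toggles a sibling div that holds the listing table. Schematically (the CSS classes, ids, and JavaScript call shown are illustrative, not copied from this patch):

    <span class="collapse-aggregated-rdds collapse-table"
        onClick="collapseTable('collapse-aggregated-rdds', 'aggregated-rdds')">
      <h4>
        <span class="collapse-table-arrow arrow-open"></span>
        <a>RDDs</a>
      </h4>
    </span>
    <div class="aggregated-rdds collapsible-table">
      {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))}
    </div>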
      } } diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index f4a736d6d439..2bc84953a56e 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -290,7 +290,8 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { private var _count = 0L /** - * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * Returns false if this accumulator has had any values added to it or the sum is non-zero. + * * @since 2.0.0 */ override def isZero: Boolean = _sum == 0L && _count == 0 @@ -368,6 +369,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { private var _sum = 0.0 private var _count = 0L + /** + * Returns false if this accumulator has had any values added to it or the sum is non-zero. + */ override def isZero: Boolean = _sum == 0.0 && _count == 0 override def copy(): DoubleAccumulator = { @@ -441,6 +445,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]()) + /** + * Returns false if this accumulator instance has any values in it. + */ override def isZero: Boolean = _list.isEmpty override def copyAndReset(): CollectionAccumulator[T] = new CollectionAccumulator @@ -479,7 +486,9 @@ class LegacyAccumulatorWrapper[R, T]( param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] { private[spark] var _value = initialValue // Current value on driver - override def isZero: Boolean = _value == param.zero(initialValue) + @transient private lazy val _zero = param.zero(initialValue) + + override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef]) override def copy(): LegacyAccumulatorWrapper[R, T] = { val acc = new LegacyAccumulatorWrapper(initialValue, param) @@ -488,7 +497,7 @@ class LegacyAccumulatorWrapper[R, T]( } override def reset(): Unit = { - _value = param.zero(initialValue) + _value = _zero } override def add(v: T): Unit = _value = param.addAccumulator(_value, v) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 40616421b5bc..ad0c0639521f 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -207,7 +207,7 @@ private[spark] object ClosureCleaner extends Logging { accessedFields: Map[Class[_], Set[String]]): Unit = { if (!isClosure(func.getClass)) { - logWarning("Expected a closure; got " + func.getClass.getName) + logDebug(s"Expected a closure; got ${func.getClass.getName}") return } diff --git a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala index d73901686b70..4b6602b50aa1 100644 --- a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala @@ -33,24 +33,14 @@ private[spark] trait CommandLineUtils { private[spark] var printStream: PrintStream = System.err // scalastyle:off println - - private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + private[spark] def printMessage(str: String): Unit = printStream.println(str) + // scalastyle:on println private[spark] 
def printErrorAndExit(str: String): Unit = { - printStream.println("Error: " + str) - printStream.println("Run with --help for usage help or --verbose for debug output") + printMessage("Error: " + str) + printMessage("Run with --help for usage help or --verbose for debug output") exitFn(1) } - // scalastyle:on println - - private[spark] def parseSparkConfProperty(pair: String): (String, String) = { - pair.split("=", 2).toSeq match { - case Seq(k, v) => (k, v) - case _ => printErrorAndExit(s"Spark config without '=': $pair") - throw new SparkException(s"Spark config without '=': $pair") - } - } - def main(args: Array[String]): Unit } diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 31d230d0fec8..21acaa95c564 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -22,9 +22,7 @@ package org.apache.spark.util * through all the elements. */ private[spark] -// scalastyle:off abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] { -// scalastyle:on private[this] var completed = false def next(): A = sub.next() diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 5e60218c5740..40383fe05026 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -48,7 +48,7 @@ import org.apache.spark.storage._ * To ensure that we provide these guarantees, follow these rules when modifying these methods: * * - Never delete any JSON fields. - * - Any new JSON fields should be optional; use `Utils.jsonOption` when reading these fields + * - Any new JSON fields should be optional; use `jsonOption` when reading these fields * in `*FromJson` methods. 
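The new private jsonOption helper (added at the bottom of this file) is what makes that rule cheap to follow: an absent field evaluates to JNothing, which becomes None, so readers can fall back to a default. A minimal sketch of the pattern:

    import org.json4s._

    implicit val format: Formats = DefaultFormats

    /** Return an option that translates JNothing to None (as defined in this patch). */
    def jsonOption(json: JValue): Option[JValue] = json match {
      case JNothing => None
      case value: JValue => Some(value)
    }

    // Older event logs may not contain "Stage Attempt ID", hence the default of 0.
    def stageAttemptIdFrom(json: JValue): Int =
      jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)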
*/ private[spark] object JsonProtocol { @@ -263,7 +263,7 @@ private[spark] object JsonProtocol { val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing) val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing) ("Stage ID" -> stageInfo.stageId) ~ - ("Stage Attempt ID" -> stageInfo.attemptId) ~ + ("Stage Attempt ID" -> stageInfo.attemptNumber) ~ ("Stage Name" -> stageInfo.name) ~ ("Number of Tasks" -> stageInfo.numTasks) ~ ("RDD Info" -> rddInfo) ~ @@ -408,7 +408,7 @@ private[spark] object JsonProtocol { ("Loss Reason" -> reason.map(_.toString)) case taskKilled: TaskKilled => ("Kill Reason" -> taskKilled.reason) - case _ => Utils.emptyJson + case _ => emptyJson } ("Reason" -> reason) ~ json } @@ -422,7 +422,7 @@ private[spark] object JsonProtocol { def jobResultToJson(jobResult: JobResult): JValue = { val result = Utils.getFormattedClassName(jobResult) val json = jobResult match { - case JobSucceeded => Utils.emptyJson + case JobSucceeded => emptyJson case jobFailed: JobFailed => JObject("Exception" -> exceptionToJson(jobFailed.exception)) } @@ -573,7 +573,7 @@ private[spark] object JsonProtocol { def taskStartFromJson(json: JValue): SparkListenerTaskStart = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = - Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskInfo = taskInfoFromJson(json \ "Task Info") SparkListenerTaskStart(stageId, stageAttemptId, taskInfo) } @@ -586,7 +586,7 @@ private[spark] object JsonProtocol { def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = - Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskType = (json \ "Task Type").extract[String] val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason") val taskInfo = taskInfoFromJson(json \ "Task Info") @@ -597,11 +597,11 @@ private[spark] object JsonProtocol { def jobStartFromJson(json: JValue): SparkListenerJobStart = { val jobId = (json \ "Job ID").extract[Int] val submissionTime = - Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) + jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int]) val properties = propertiesFromJson(json \ "Properties") // The "Stage Infos" field was added in Spark 1.2.0 - val stageInfos = Utils.jsonOption(json \ "Stage Infos") + val stageInfos = jsonOption(json \ "Stage Infos") .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { stageIds.map { id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown") @@ -613,7 +613,7 @@ private[spark] object JsonProtocol { def jobEndFromJson(json: JValue): SparkListenerJobEnd = { val jobId = (json \ "Job ID").extract[Int] val completionTime = - Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) + jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) val jobResult = jobResultFromJson(json \ "Job Result") SparkListenerJobEnd(jobId, completionTime, jobResult) } @@ -630,15 +630,15 @@ private[spark] object JsonProtocol { def blockManagerAddedFromJson(json: JValue): SparkListenerBlockManagerAdded = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") val maxMem = (json \ "Maximum 
Memory").extract[Long] - val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) - val maxOnHeapMem = Utils.jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) - val maxOffHeapMem = Utils.jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) + val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + val maxOnHeapMem = jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) + val maxOffHeapMem = jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) SparkListenerBlockManagerAdded(time, blockManagerId, maxMem, maxOnHeapMem, maxOffHeapMem) } def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") - val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) SparkListenerBlockManagerRemoved(time, blockManagerId) } @@ -648,11 +648,11 @@ private[spark] object JsonProtocol { def applicationStartFromJson(json: JValue): SparkListenerApplicationStart = { val appName = (json \ "App Name").extract[String] - val appId = Utils.jsonOption(json \ "App ID").map(_.extract[String]) + val appId = jsonOption(json \ "App ID").map(_.extract[String]) val time = (json \ "Timestamp").extract[Long] val sparkUser = (json \ "User").extract[String] - val appAttemptId = Utils.jsonOption(json \ "App Attempt ID").map(_.extract[String]) - val driverLogs = Utils.jsonOption(json \ "Driver Logs").map(mapFromJson) + val appAttemptId = jsonOption(json \ "App Attempt ID").map(_.extract[String]) + val driverLogs = jsonOption(json \ "Driver Logs").map(mapFromJson) SparkListenerApplicationStart(appName, appId, time, sparkUser, appAttemptId, driverLogs) } @@ -703,19 +703,19 @@ private[spark] object JsonProtocol { def stageInfoFromJson(json: JValue): StageInfo = { val stageId = (json \ "Stage ID").extract[Int] - val attemptId = Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + val attemptId = jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson) - val parentIds = Utils.jsonOption(json \ "Parent IDs") + val parentIds = jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) - val details = Utils.jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") - val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]) - val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) - val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) + val details = jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") + val submissionTime = jsonOption(json \ "Submission Time").map(_.extract[Long]) + val completionTime = jsonOption(json \ "Completion Time").map(_.extract[Long]) + val failureReason = jsonOption(json \ "Failure Reason").map(_.extract[String]) val accumulatedValues = { - Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { + jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq.empty[AccumulableInfo] } @@ -735,17 +735,17 @@ private[spark] object JsonProtocol 
{ def taskInfoFromJson(json: JValue): TaskInfo = { val taskId = (json \ "Task ID").extract[Long] val index = (json \ "Index").extract[Int] - val attempt = Utils.jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) + val attempt = jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) val launchTime = (json \ "Launch Time").extract[Long] val executorId = (json \ "Executor ID").extract[String].intern() val host = (json \ "Host").extract[String].intern() val taskLocality = TaskLocality.withName((json \ "Locality").extract[String]) - val speculative = Utils.jsonOption(json \ "Speculative").exists(_.extract[Boolean]) + val speculative = jsonOption(json \ "Speculative").exists(_.extract[Boolean]) val gettingResultTime = (json \ "Getting Result Time").extract[Long] val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] - val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean]) - val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { + val killed = jsonOption(json \ "Killed").exists(_.extract[Boolean]) + val accumulables = jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq.empty[AccumulableInfo] } @@ -762,13 +762,13 @@ private[spark] object JsonProtocol { def accumulableInfoFromJson(json: JValue): AccumulableInfo = { val id = (json \ "ID").extract[Long] - val name = Utils.jsonOption(json \ "Name").map(_.extract[String]) - val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } - val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } - val internal = Utils.jsonOption(json \ "Internal").exists(_.extract[Boolean]) + val name = jsonOption(json \ "Name").map(_.extract[String]) + val update = jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } + val value = jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } + val internal = jsonOption(json \ "Internal").exists(_.extract[Boolean]) val countFailedValues = - Utils.jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) - val metadata = Utils.jsonOption(json \ "Metadata").map(_.extract[String]) + jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) + val metadata = jsonOption(json \ "Metadata").map(_.extract[String]) new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) } @@ -821,49 +821,49 @@ private[spark] object JsonProtocol { metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long]) // Shuffle read metrics - Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => + jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => val readMetrics = metrics.createTempShuffleReadMetrics() readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) - Utils.jsonOption(readJson \ "Remote Bytes Read To Disk") + jsonOption(readJson \ "Remote Bytes Read To Disk") .foreach { v => readMetrics.incRemoteBytesReadToDisk(v.extract[Long])} readMetrics.incLocalBytesRead( - Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) readMetrics.incFetchWaitTime((readJson \ "Fetch Wait 
Time").extract[Long]) readMetrics.incRecordsRead( - Utils.jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) metrics.mergeShuffleReadMetrics() } // Shuffle write metrics // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes. - Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => + jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => val writeMetrics = metrics.shuffleWriteMetrics writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long]) writeMetrics.incRecordsWritten( - Utils.jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) + jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long]) } // Output metrics - Utils.jsonOption(json \ "Output Metrics").foreach { outJson => + jsonOption(json \ "Output Metrics").foreach { outJson => val outputMetrics = metrics.outputMetrics outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long]) outputMetrics.setRecordsWritten( - Utils.jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) + jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) } // Input metrics - Utils.jsonOption(json \ "Input Metrics").foreach { inJson => + jsonOption(json \ "Input Metrics").foreach { inJson => val inputMetrics = metrics.inputMetrics inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) inputMetrics.incRecordsRead( - Utils.jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) } // Updated blocks - Utils.jsonOption(json \ "Updated Blocks").foreach { blocksJson => + jsonOption(json \ "Updated Blocks").foreach { blocksJson => metrics.setUpdatedBlockStatuses(blocksJson.extract[List[JValue]].map { blockJson => val id = BlockId((blockJson \ "Block ID").extract[String]) val status = blockStatusFromJson(blockJson \ "Status") @@ -897,7 +897,7 @@ private[spark] object JsonProtocol { val shuffleId = (json \ "Shuffle ID").extract[Int] val mapId = (json \ "Map ID").extract[Int] val reduceId = (json \ "Reduce ID").extract[Int] - val message = Utils.jsonOption(json \ "Message").map(_.extract[String]) + val message = jsonOption(json \ "Message").map(_.extract[String]) new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId, message.getOrElse("Unknown reason")) case `exceptionFailure` => @@ -905,9 +905,9 @@ private[spark] object JsonProtocol { val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") val fullStackTrace = - Utils.jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull + jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x - val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") + val accumUpdates = jsonOption(json \ "Accumulator Updates") .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulators().map(acc => { acc.toInfo(Some(acc.value), None) @@ -915,21 +915,21 @@ private[spark] object JsonProtocol { ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost case 
`taskKilled` => - val killReason = Utils.jsonOption(json \ "Kill Reason") + val killReason = jsonOption(json \ "Kill Reason") .map(_.extract[String]).getOrElse("unknown reason") TaskKilled(killReason) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility // for reading those logs, we need to provide default values for all the fields. - val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) - val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) - val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) + val jobId = jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) + val partitionId = jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) + val attemptNo = jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => - val exitCausedByApp = Utils.jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) - val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) - val reason = Utils.jsonOption(json \ "Loss Reason").map(_.extract[String]) + val exitCausedByApp = jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) + val executorId = jsonOption(json \ "Executor ID").map(_.extract[String]) + val reason = jsonOption(json \ "Loss Reason").map(_.extract[String]) ExecutorLostFailure( executorId.getOrElse("Unknown"), exitCausedByApp.getOrElse(true), @@ -968,11 +968,11 @@ private[spark] object JsonProtocol { def rddInfoFromJson(json: JValue): RDDInfo = { val rddId = (json \ "RDD ID").extract[Int] val name = (json \ "Name").extract[String] - val scope = Utils.jsonOption(json \ "Scope") + val scope = jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) - val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") - val parentIds = Utils.jsonOption(json \ "Parent IDs") + val callsite = jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") + val parentIds = jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) val storageLevel = storageLevelFromJson(json \ "Storage Level") @@ -1029,7 +1029,7 @@ private[spark] object JsonProtocol { } def propertiesFromJson(json: JValue): Properties = { - Utils.jsonOption(json).map { value => + jsonOption(json).map { value => val properties = new Properties mapFromJson(json).foreach { case (k, v) => properties.setProperty(k, v) } properties @@ -1058,4 +1058,14 @@ private[spark] object JsonProtocol { e } + /** Return an option that translates JNothing to None */ + private def jsonOption(json: JValue): Option[JValue] = { + json match { + case JNothing => None + case value: JValue => Some(value) + } + } + + private def emptyJson: JObject = JObject(List[JField]()) + } diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 76a56298aaeb..b25a731401f2 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -81,7 +81,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { try { doPostEvent(listener, event) } catch { - case NonFatal(e) => + case NonFatal(e) if 
!isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { if (maybeTimerContext != null) { @@ -97,6 +97,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { */ protected def doPostEvent(listener: L, event: E): Unit + /** Allows bus implementations to prevent error logging for certain exceptions. */ + protected def isIgnorableException(e: Throwable): Boolean = false + private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass listeners.asScala.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 4001fac3c3d5..b702838fa257 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -143,7 +143,7 @@ private[spark] object ShutdownHookManager extends Logging { } /** - * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * Adds a shutdown hook with the given priority. Hooks with higher priority values run * first. * * @param hook The code to run during shutdown. diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 585330297314..dcad1b914038 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -25,7 +25,7 @@ import java.net._ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets -import java.nio.file.{Files, Paths} +import java.nio.file.Files import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean @@ -51,9 +51,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException -import org.json4s._ import org.slf4j.Logger import org.apache.spark._ @@ -63,6 +61,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace} /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -672,7 +671,6 @@ private[spark] object Utils extends Logging { logDebug("fetchFile not using security") uc = new URL(url).openConnection() } - Utils.setupSecureURLConnection(uc, securityMgr) val timeoutMs = conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 @@ -1017,70 +1015,18 @@ private[spark] object Utils extends Logging { " " + (System.currentTimeMillis - startTimeMs) + " ms" } - private def listFilesSafely(file: File): Seq[File] = { - if (file.exists()) { - val files = file.listFiles() - if (files == null) { - throw new IOException("Failed to list files for dir: " + file) - } - files - } else { - List() - } - } - - /** - * Lists files recursively. 
- */ - def recursiveList(f: File): Array[File] = { - require(f.isDirectory) - val current = f.listFiles - current ++ current.filter(_.isDirectory).flatMap(recursiveList) - } - /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. * Throws an exception if deletion is unsuccessful. */ - def deleteRecursively(file: File) { + def deleteRecursively(file: File): Unit = { if (file != null) { - try { - if (file.isDirectory && !isSymlink(file)) { - var savedIOException: IOException = null - for (child <- listFilesSafely(file)) { - try { - deleteRecursively(child) - } catch { - // In case of multiple exceptions, only last one will be thrown - case ioe: IOException => savedIOException = ioe - } - } - if (savedIOException != null) { - throw savedIOException - } - ShutdownHookManager.removeShutdownDeleteDir(file) - } - } finally { - if (file.delete()) { - logTrace(s"${file.getAbsolutePath} has been deleted") - } else { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) - } - } - } + JavaUtils.deleteRecursively(file) + ShutdownHookManager.removeShutdownDeleteDir(file) } } - /** - * Check to see if file is a symbolic link. - */ - def isSymlink(file: File): Boolean = { - return Files.isSymbolicLink(Paths.get(file.toURI)) - } - /** * Determines if a directory contains any files newer than cutoff seconds. * @@ -1828,7 +1774,7 @@ private[spark] object Utils extends Logging { * [[scala.collection.Iterator#size]] because it uses a for loop, which is slightly slower * in the current version of Scala. */ - def getIteratorSize[T](iterator: Iterator[T]): Long = { + def getIteratorSize(iterator: Iterator[_]): Long = { var count = 0L while (iterator.hasNext) { count += 1L @@ -1875,17 +1821,6 @@ private[spark] object Utils extends Logging { obj.getClass.getSimpleName.replace("$", "") } - /** Return an option that translates JNothing to None */ - def jsonOption(json: JValue): Option[JValue] = { - json match { - case JNothing => None - case value: JValue => Some(value) - } - } - - /** Return an empty JSON object */ - def emptyJson: JsonAST.JObject = JObject(List[JField]()) - /** * Return a Hadoop FileSystem with the scheme encoded in the given path. */ @@ -1900,15 +1835,6 @@ private[spark] object Utils extends Logging { getHadoopFileSystem(new URI(path), conf) } - /** - * Return the absolute path of a file in the given directory. - */ - def getFilePath(dir: File, fileName: String): Path = { - assert(dir.isDirectory) - val path = new File(dir, fileName).getAbsolutePath - new Path(path) - } - /** * Whether the underlying operating system is Windows. */ @@ -1931,13 +1857,6 @@ private[spark] object Utils extends Logging { sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") } - /** - * Strip the directory from a path name - */ - def stripDirectory(path: String): String = { - new File(path).getName - } - /** * Terminates a process waiting for at most the specified duration. * @@ -2168,7 +2087,22 @@ private[spark] object Utils extends Logging { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. 
val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) - threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace) + threadInfos.sortWith { case (threadTrace1, threadTrace2) => + val v1 = if (threadTrace1.getThreadName.contains("Executor task launch")) 1 else 0 + val v2 = if (threadTrace2.getThreadName.contains("Executor task launch")) 1 else 0 + if (v1 == v2) { + val name1 = threadTrace1.getThreadName().toLowerCase(Locale.ROOT) + val name2 = threadTrace2.getThreadName().toLowerCase(Locale.ROOT) + val nameCmpRes = name1.compareTo(name2) + if (nameCmpRes == 0) { + threadTrace1.getThreadId < threadTrace2.getThreadId + } else { + nameCmpRes < 0 + } + } else { + v1 > v2 + } + }.map(threadInfoToThreadStackTrace) } def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = { @@ -2184,14 +2118,14 @@ private[spark] object Utils extends Logging { private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = { val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap - val stackTrace = threadInfo.getStackTrace.map { frame => + val stackTrace = StackTrace(threadInfo.getStackTrace.map { frame => monitors.get(frame) match { case Some(monitor) => monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" case None => frame.toString } - }.mkString("\n") + }) // use a set to dedup re-entrant locks that are held at multiple places val heldLocks = @@ -2333,50 +2267,6 @@ private[spark] object Utils extends Logging { org.apache.log4j.Logger.getRootLogger().setLevel(l) } - /** - * config a log4j properties used for testsuite - */ - def configTestLog4j(level: String): Unit = { - val pro = new Properties() - pro.put("log4j.rootLogger", s"$level, console") - pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") - pro.put("log4j.appender.console.target", "System.err") - pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") - pro.put("log4j.appender.console.layout.ConversionPattern", - "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") - PropertyConfigurator.configure(pro) - } - - /** - * If the given URL connection is HttpsURLConnection, it sets the SSL socket factory and - * the host verifier from the given security manager. - */ - def setupSecureURLConnection(urlConnection: URLConnection, sm: SecurityManager): URLConnection = { - urlConnection match { - case https: HttpsURLConnection => - sm.sslSocketFactory.foreach(https.setSSLSocketFactory) - sm.hostnameVerifier.foreach(https.setHostnameVerifier) - https - case connection => connection - } - } - - def invoke( - clazz: Class[_], - obj: AnyRef, - methodName: String, - args: (Class[_], AnyRef)*): AnyRef = { - val (types, values) = args.unzip - val method = clazz.getDeclaredMethod(methodName, types: _*) - method.setAccessible(true) - method.invoke(obj, values.toSeq: _*) - } - - // Limit of bytes for total size of results (default is 1GB) - def getMaxResultSize(conf: SparkConf): Long = { - memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20 - } - /** * Return the current system LD_LIBRARY_PATH name */ @@ -2412,16 +2302,20 @@ private[spark] object Utils extends Logging { } /** - * Return the value of a config either through the SparkConf or the Hadoop configuration - * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf - * if the key is not set in the Hadoop configuration. 
+ * Return the value of a config either through the SparkConf or the Hadoop configuration. + * We check whether the key is set in the SparkConf before looking at any Hadoop configuration. + * If the key is set in SparkConf, the value is taken from SparkConf regardless of whether + * the application is running on YARN. + * Only when the key is not set in SparkConf and the application is running on YARN + * is the value taken from the Hadoop configuration. */ def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = { - val sparkValue = conf.get(key, default) - if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") { - new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, sparkValue) + if (conf.contains(key)) { + conf.get(key, default) + } else if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") { + new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, default) } else { - sparkValue + default } } @@ -2609,16 +2503,6 @@ private[spark] object Utils extends Logging { SignalUtils.registerLogger(log) } - /** - * Unions two comma-separated lists of files and filters out empty strings. - */ - def unionFileLists(leftList: Option[String], rightList: Option[String]): Set[String] = { - var allFiles = Set.empty[String] - leftList.foreach { value => allFiles ++= value.split(",") } - rightList.foreach { value => allFiles ++= value.split(",") } - allFiles.filter { _.nonEmpty } - } - /** * Return the jar files pointed by the "spark.jars" property. Spark internally will distribute * these jars through file server. In the YARN mode, it will return an empty list, since YARN * @@ -2805,6 +2689,21 @@ private[spark] object Utils extends Logging { s"k8s://$resolvedURL" } + + /** + * Replaces all the {{EXECUTOR_ID}} occurrences with the Executor Id + * and {{APP_ID}} occurrences with the App Id. + */ + def substituteAppNExecIds(opt: String, appId: String, execId: String): String = { + opt.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId) + }
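The reworked getSparkOrYarnConfig above gives SparkConf strict precedence: a key set in SparkConf always wins, the Hadoop/YARN configuration is consulted only when the key is absent from SparkConf and the master is "yarn", and the caller's default is returned otherwise. A minimal self-contained sketch of that precedence rule, assuming hypothetical names (resolve, sparkSettings, yarnSettings are not Spark APIs):

// Sketch of the lookup order implemented above, using plain maps instead of SparkConf.
def resolve(
    sparkSettings: Map[String, String],
    yarnSettings: Map[String, String],
    runningOnYarn: Boolean,
    key: String,
    default: String): String = {
  sparkSettings.get(key) match {
    case Some(value) => value                                           // SparkConf always wins
    case None if runningOnYarn => yarnSettings.getOrElse(key, default)  // fall back to YARN conf
    case None => default                                                // neither source sets it
  }
}

For example, resolve(Map("k" -> "a"), Map("k" -> "b"), runningOnYarn = true, "k", "c") yields "a", while resolve(Map.empty, Map("k" -> "b"), runningOnYarn = true, "k", "c") yields "b".

+ + /** + * Replaces all the {{APP_ID}} occurrences with the App Id.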
+ */ + def substituteAppId(opt: String, appId: String): String = { + opt.replace("{{APP_ID}}", appId) + } } private[util] object CallerContext extends Logging { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 375f4a692122..5c6dd45ec58e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -463,7 +463,7 @@ class ExternalAppendOnlyMap[K, V, C]( // An intermediate stream that reads from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams - private var deserializeStream = nextBatchStream() + private var deserializeStream: DeserializationStream = null private var nextItem: (K, C) = null private var objectsRead = 0 @@ -528,7 +528,11 @@ class ExternalAppendOnlyMap[K, V, C]( override def hasNext: Boolean = { if (nextItem == null) { if (deserializeStream == null) { - return false + // In case of deserializeStream has not been initialized + deserializeStream = nextBatchStream() + if (deserializeStream == null) { + return false + } } nextItem = readNextItem() } @@ -536,19 +540,18 @@ class ExternalAppendOnlyMap[K, V, C]( } override def next(): (K, C) = { - val item = if (nextItem == null) readNextItem() else nextItem - if (item == null) { + if (!hasNext) { throw new NoSuchElementException } + val item = nextItem nextItem = null item } private def cleanup() { batchIndex = batchOffsets.length // Prevent reading any other batch - val ds = deserializeStream - if (ds != null) { - ds.close() + if (deserializeStream != null) { + deserializeStream.close() deserializeStream = null } if (fileStream != null) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 8183f825592c..81457b53cd81 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} /** @@ -41,7 +42,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) protected def forceSpill(): Boolean // Number of elements read from input since last spill - protected def elementsRead: Long = _elementsRead + protected def elementsRead: Int = _elementsRead // Called by subclasses every time a record is read // It's used for checking spilling frequency @@ -54,15 +55,15 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) // Force this collection to spill when there are this many elements in memory // For testing only - private[this] val numElementsForceSpillThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) + private[this] val numElementsForceSpillThreshold: Int = + SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) // Threshold for this collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 @volatile private[this] var myMemoryThreshold = 
initialMemoryThreshold // Number of elements read from input since last spill - private[this] var _elementsRead = 0L + private[this] var _elementsRead = 0 // Number of bytes spilled in total @volatile private[this] var _memoryBytesSpilled = 0L diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 7367af7888bd..3ae8dfcc1cb6 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -63,10 +63,15 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - while (bytes.remaining() > 0) { - val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) - bytes.limit(bytes.position() + ioSize) - channel.write(bytes) + val curChunkLimit = bytes.limit() + while (bytes.hasRemaining) { + try { + val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) + bytes.limit(bytes.position() + ioSize) + channel.write(bytes) + } finally { + bytes.limit(curChunkLimit) + } } } } diff --git a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java index 3440e1aea2f4..22db3592ecc9 100644 --- a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java @@ -37,7 +37,7 @@ public abstract class GenericFileInputStreamSuite { protected File inputFile; - protected InputStream inputStream; + protected InputStream[] inputStreams; @Before public void setUp() throws IOException { @@ -54,77 +54,91 @@ public void tearDown() { @Test public void testReadOneByte() throws IOException { - for (int i = 0; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testReadMultipleBytes() throws IOException { - byte[] readBytes = new byte[8 * 1024]; - int i = 0; - while (i < randomBytes.length) { - int read = inputStream.read(readBytes, 0, 8 * 1024); - for (int j = 0; j < read; j++) { - assertEquals(randomBytes[i], readBytes[j]); - i++; + for (InputStream inputStream: inputStreams) { + byte[] readBytes = new byte[8 * 1024]; + int i = 0; + while (i < randomBytes.length) { + int read = inputStream.read(readBytes, 0, 8 * 1024); + for (int j = 0; j < read; j++) { + assertEquals(randomBytes[i], readBytes[j]); + i++; + } } } } @Test public void testBytesSkipped() throws IOException { - assertEquals(1024, inputStream.skip(1024)); - for (int i = 1024; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testBytesSkippedAfterRead() throws IOException { - for (int i = 0; i < 1024; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - assertEquals(1024, inputStream.skip(1024)); - for (int i = 2048; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < 1024; i++) { 
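The ChunkedByteBuffer.writeFully change above bounds each write at bufferWriteChunkSize by temporarily lowering the chunk's limit, and the added finally block restores the original limit even if channel.write throws, so the rest of the chunk is not left unreachable. A minimal standalone sketch of the same save-and-restore pattern on a plain ByteBuffer (illustrative only; writeInSlices and maxSlice are made-up names, not Spark code):

import java.nio.ByteBuffer
import java.nio.channels.WritableByteChannel

// Write `buf` in slices of at most `maxSlice` bytes, restoring the original limit each time.
def writeInSlices(channel: WritableByteChannel, buf: ByteBuffer, maxSlice: Int): Unit = {
  val originalLimit = buf.limit()
  while (buf.hasRemaining) {
    try {
      val sliceSize = math.min(buf.remaining(), maxSlice)
      buf.limit(buf.position() + sliceSize) // expose only the current slice to the channel
      channel.write(buf)
    } finally {
      buf.limit(originalLimit)              // undo the temporary limit, even on failure
    }
  }
}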
+ assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(1024, inputStream.skip(1024)); + for (int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testNegativeBytesSkippedAfterRead() throws IOException { - for (int i = 0; i < 1024; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - // Skipping negative bytes should essential be a no-op - assertEquals(0, inputStream.skip(-1)); - assertEquals(0, inputStream.skip(-1024)); - assertEquals(0, inputStream.skip(Long.MIN_VALUE)); - assertEquals(1024, inputStream.skip(1024)); - for (int i = 2048; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < 1024; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + // Skipping negative bytes should essential be a no-op + assertEquals(0, inputStream.skip(-1)); + assertEquals(0, inputStream.skip(-1024)); + assertEquals(0, inputStream.skip(Long.MIN_VALUE)); + assertEquals(1024, inputStream.skip(1024)); + for (int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testSkipFromFileChannel() throws IOException { - // Since the buffer is smaller than the skipped bytes, this will guarantee - // we skip from underlying file channel. - assertEquals(1024, inputStream.skip(1024)); - for (int i = 1024; i < 2048; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - assertEquals(256, inputStream.skip(256)); - assertEquals(256, inputStream.skip(256)); - assertEquals(512, inputStream.skip(512)); - for (int i = 3072; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + // Since the buffer is smaller than the skipped bytes, this will guarantee + // we skip from underlying file channel. 
+ assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < 2048; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(256, inputStream.skip(256)); + assertEquals(256, inputStream.skip(256)); + assertEquals(512, inputStream.skip(512)); + for (int i = 3072; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testBytesSkippedAfterEOF() throws IOException { - assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); - assertEquals(-1, inputStream.read()); + for (InputStream inputStream: inputStreams) { + assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); + assertEquals(-1, inputStream.read()); + } } } diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java index 211b33a1a9fb..a320f8662f70 100644 --- a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java @@ -18,6 +18,7 @@ import org.junit.Before; +import java.io.InputStream; import java.io.IOException; /** @@ -28,6 +29,9 @@ public class NioBufferedInputStreamSuite extends GenericFileInputStreamSuite { @Before public void setUp() throws IOException { super.setUp(); - inputStream = new NioBufferedFileInputStream(inputFile); + inputStreams = new InputStream[] { + new NioBufferedFileInputStream(inputFile), // default + new NioBufferedFileInputStream(inputFile, 123) // small, unaligned buffer + }; } } diff --git a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java index 918ddc4517ec..bfa1e0b90882 100644 --- a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java @@ -19,16 +19,27 @@ import org.junit.Before; import java.io.IOException; +import java.io.InputStream; /** - * Tests functionality of {@link NioBufferedFileInputStream} + * Tests functionality of {@link ReadAheadInputStream} */ public class ReadAheadInputStreamSuite extends GenericFileInputStreamSuite { @Before public void setUp() throws IOException { super.setUp(); - inputStream = new ReadAheadInputStream( - new NioBufferedFileInputStream(inputFile), 8 * 1024, 4 * 1024); + inputStreams = new InputStream[] { + // Tests equal and aligned buffers of wrapped and outer stream. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 8 * 1024), 8 * 1024), + // Tests aligned buffers, wrapped bigger than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 3 * 1024), 2 * 1024), + // Tests aligned buffers, wrapped smaller than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 2 * 1024), 3 * 1024), + // Tests unaligned buffers, wrapped bigger than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 321), 123), + // Tests unaligned buffers, wrapped smaller than outer. 
+ new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 123), 321) + }; } } diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index c2261c204cd4..6a1a38c1a54f 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -30,6 +31,7 @@ import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; +import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -107,7 +109,7 @@ public void testChildProcLauncher() throws Exception { .addSparkArg(opts.CONF, String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, - "-Dfoo=bar -Dtest.appender=childproc") + "-Dfoo=bar -Dtest.appender=console") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) @@ -133,6 +135,12 @@ public void testInProcessLauncher() throws Exception { p.put(e.getKey(), e.getValue()); } System.setProperties(p); + // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. + // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. + // See SPARK-23019 and SparkContext.stop() for details. + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); + }); } } @@ -141,26 +149,82 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - transitions.add(h.getState()); + synchronized (transitions) { + transitions.add(h.getState()); + } return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); + SparkAppHandle handle = null; + try { + synchronized (InProcessTestApp.LOCK) { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + // SPARK-23020: see doc for InProcessTestApp.LOCK for a description of the race. Here + // we wait until we know that the connection between the app and the launcher has been + // established before allowing the app to finish. + final SparkAppHandle _handle = handle; + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertNotEquals(SparkAppHandle.State.UNKNOWN, _handle.getState()); + }); + + InProcessTestApp.LOCK.wait(5000); + } + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. 
+ List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); + } finally { + if (handle != null) { + handle.kill(); + } + } + } + + @Test + public void testInProcessLauncherDoesNotKillJvm() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + List wrongArgs = Arrays.asList( + new String[] { "--unknown" }, + new String[] { opts.DEPLOY_MODE, "invalid" }); + + for (String[] args : wrongArgs) { + InProcessLauncher launcher = new InProcessLauncher() + .setAppResource(SparkLauncher.NO_RESOURCE); + switch (args.length) { + case 2: + launcher.addSparkArg(args[0], args[1]); + break; + + case 1: + launcher.addSparkArg(args[0]); + break; + + default: + fail("FIXME: invalid test."); + } + SparkAppHandle handle = launcher.startApplication(); + waitFor(handle); + assertEquals(SparkAppHandle.State.FAILED, handle.getState()); + } + + // Run --version, which is useless as a use case, but should succeed and not exit the JVM. + // The expected state is "LOST" since "--version" doesn't report state back to the handle. + SparkAppHandle handle = new InProcessLauncher().addSparkArg(opts.VERSION).startApplication(); waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); + assertEquals(SparkAppHandle.State.LOST, handle.getState()); } public static class SparkLauncherTestApp { @@ -176,10 +240,26 @@ public static void main(String[] args) throws Exception { public static class InProcessTestApp { + /** + * SPARK-23020: there's a race caused by a child app finishing too quickly. This would cause + * the InProcessAppHandle to dispose of itself even before the child connection was properly + * established, so no state changes would be detected for the application and its final + * state would be LOST. + * + * It's not really possible to fix that race safely in the handle code itself without changing + * the way in-process apps talk to the launcher library, so we work around that in the test by + * synchronizing on this object. 
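The LOCK monitor described above is a plain wait/notify handshake: the test starts the in-process app and verifies the handle is connected while holding the monitor, which blocks the child from reaching its notifyAll (and thus from finishing) until the test releases the monitor by calling wait. A minimal Scala sketch of that handshake pattern under the same assumptions (HandshakeDemo and the thread body are made up, not the test code):

object HandshakeDemo {
  private val lock = new Object

  def main(args: Array[String]): Unit = {
    val child = new Thread(new Runnable {
      override def run(): Unit = {
        // ... the child application's work would happen here ...
        lock.synchronized { lock.notifyAll() } // signal the parent, as InProcessTestApp does
      }
    })

    lock.synchronized {
      // Holding the monitor here means the child cannot enter its synchronized
      // block (and notify) until the parent releases it by calling wait().
      child.start()
      lock.wait(5000) // monitor is released while waiting; woken by notifyAll or the timeout
    }
    child.join()
  }
}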
+ */ + public static final Object LOCK = new Object(); + public static void main(String[] args) throws Exception { assertNotEquals(0, args.length); assertEquals(args[0], "hello"); new SparkContext().stop(); + + synchronized (LOCK) { + LOCK.notifyAll(); + } } } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 46b0516e3614..d7d2d0b012bd 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -21,6 +21,7 @@ import org.junit.Test; import org.apache.spark.SparkConf; +import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; public class TaskMemoryManagerSuite { @@ -68,6 +69,34 @@ public void encodePageNumberAndOffsetOnHeap() { Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); } + @Test + public void freeingPageSetsPageNumberToSpecialConstant() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = manager.allocatePage(256, c); + c.freePage(dataPage); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.getPageNumber()); + } + + @Test(expected = AssertionError.class) + public void freeingPageDirectlyInAllocatorTriggersAssertionError() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = manager.allocatePage(256, c); + MemoryAllocator.HEAP.free(dataPage); + } + + @Test(expected = AssertionError.class) + public void callingFreePageOnDirectlyAllocatedPageTriggersAssertionError() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = MemoryAllocator.HEAP.allocate(256); + manager.freePage(dataPage, c); + } + @Test public void cooperativeSpilling() { final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 24a55df84a24..0d5c5ea7903e 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -95,7 +95,7 @@ public void tearDown() { @SuppressWarnings("unchecked") public void setUp() throws IOException { MockitoAnnotations.initMocks(this); - tempDir = Utils.createTempDir("test", "test"); + tempDir = Utils.createTempDir(null, "test"); mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index af4975c888d6..411cd5cb5733 100644 --- 
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -72,8 +72,10 @@ public class UnsafeExternalSorterSuite { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 594f07dd780f..c14553232851 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -98,8 +98,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; @@ -164,8 +166,10 @@ public void freeAfterOOM() { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 94f5805853e1..f8e233a05a44 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -38,6 +38,7 @@ public static void test() { tc.attemptNumber(); tc.partitionId(); tc.stageId(); + tc.stageAttemptNumber(); tc.taskAttemptId(); } @@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); context.stageId(); + context.stageAttemptNumber(); context.partitionId(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index f2c3ec5da889..4fecf84db65a 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 
@@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,9 +140,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] }, { "id" : "local-1422981780767", @@ -125,9 +155,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981779720, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981788731 } ] }, { "id" : "local-1422981759269", @@ -140,8 +170,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json new file mode 100644 index 000000000000..5e9e8230e274 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json @@ -0,0 +1,639 @@ +{ + "status": "COMPLETE", + "stageId": 0, + "attemptId": 0, + "numTasks": 10, + "numActiveTasks": 0, + "numCompleteTasks": 10, + "numFailedTasks": 2, + "numKilledTasks": 0, + "numCompletedIndices": 10, + "executorRunTime": 761, + "executorCpuTime": 269916000, + "submissionTime": 
"2018-01-09T10:21:18.152GMT", + "firstTaskLaunchedTime": "2018-01-09T10:21:18.347GMT", + "completionTime": "2018-01-09T10:21:19.062GMT", + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleReadBytes": 0, + "shuffleReadRecords": 0, + "shuffleWriteBytes": 460, + "shuffleWriteRecords": 10, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "name": "map at :26", + "details": "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", + "schedulingPool": "default", + "rddIds": [ + 1, + 0 + ], + "accumulatorUpdates": [], + "tasks": { + "0": { + "taskId": 0, + "index": 0, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.347GMT", + "duration": 562, + "executorId": "0", + "host": "172.30.65.138", + "status": "FAILED", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics": { + "executorDeserializeTime": 0, + "executorDeserializeCpuTime": 0, + "executorRunTime": 460, + "executorCpuTime": 0, + "resultSize": 0, + "jvmGcTime": 14, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 0, + "writeTime": 3873006, + "recordsWritten": 0 + } + } + }, + "5": { + "taskId": 5, + "index": 3, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.958GMT", + "duration": 22, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": 
"PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2586000, + "executorRunTime": 9, + "executorCpuTime": 9635000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 262919, + "recordsWritten": 1 + } + } + }, + "10": { + "taskId": 10, + "index": 8, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.034GMT", + "duration": 12, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 1803000, + "executorRunTime": 6, + "executorCpuTime": 6157000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 243647, + "recordsWritten": 1 + } + } + }, + "1": { + "taskId": 1, + "index": 1, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.364GMT", + "duration": 565, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 301, + "executorDeserializeCpuTime": 200029000, + "executorRunTime": 212, + "executorCpuTime": 198479000, + "resultSize": 1115, + "jvmGcTime": 13, + "resultSerializationTime": 1, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 2409488, + "recordsWritten": 1 + } + } + }, + "6": { + "taskId": 6, + "index": 4, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.980GMT", + "duration": 16, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2610000, + "executorRunTime": 10, + "executorCpuTime": 9622000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 
+ }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 385110, + "recordsWritten": 1 + } + } + }, + "9": { + "taskId": 9, + "index": 7, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.022GMT", + "duration": 12, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 1981000, + "executorRunTime": 7, + "executorCpuTime": 6335000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 259354, + "recordsWritten": 1 + } + } + }, + "2": { + "taskId": 2, + "index": 2, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.899GMT", + "duration": 27, + "executorId": "0", + "host": "172.30.65.138", + "status": "FAILED", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics": { + "executorDeserializeTime": 0, + "executorDeserializeCpuTime": 0, + "executorRunTime": 16, + "executorCpuTime": 0, + "resultSize": 0, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 0, + "writeTime": 126128, + "recordsWritten": 0 + } + } + }, + "7": { + "taskId": 7, + "index": 5, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.996GMT", + "duration": 15, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + 
"accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 2231000, + "executorRunTime": 9, + "executorCpuTime": 8407000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 205520, + "recordsWritten": 1 + } + } + }, + "3": { + "taskId": 3, + "index": 0, + "attempt": 1, + "launchTime": "2018-01-09T10:21:18.919GMT", + "duration": 24, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 8, + "executorDeserializeCpuTime": 8878000, + "executorRunTime": 10, + "executorCpuTime": 9364000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 207014, + "recordsWritten": 1 + } + } + }, + "11": { + "taskId": 11, + "index": 9, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.045GMT", + "duration": 15, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2017000, + "executorRunTime": 6, + "executorCpuTime": 6676000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 233652, + "recordsWritten": 1 + } + } + }, + "8": { + "taskId": 8, + "index": 6, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.011GMT", + "duration": 11, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 1, + "executorDeserializeCpuTime": 1554000, + "executorRunTime": 7, + "executorCpuTime": 6034000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + 
"remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 213296, + "recordsWritten": 1 + } + } + }, + "4": { + "taskId": 4, + "index": 2, + "attempt": 1, + "launchTime": "2018-01-09T10:21:18.943GMT", + "duration": 16, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 2211000, + "executorRunTime": 9, + "executorCpuTime": 9207000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 292381, + "recordsWritten": 1 + } + } + } + }, + "executorSummary": { + "0": { + "taskTime": 589, + "failedTasks": 2, + "succeededTasks": 0, + "killedTasks": 0, + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleRead": 0, + "shuffleReadRecords": 0, + "shuffleWrite": 0, + "shuffleWriteRecords": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "isBlacklistedForStage": true + }, + "1": { + "taskTime": 708, + "failedTasks": 0, + "succeededTasks": 10, + "killedTasks": 0, + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleRead": 0, + "shuffleReadRecords": 0, + "shuffleWrite": 460, + "shuffleWriteRecords": 10, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "isBlacklistedForStage": false + } + }, + "killedTasksSummary": {} +} diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json new file mode 100644 index 000000000000..acd4cc53de6c --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json @@ -0,0 +1,783 @@ +{ + "status" : "COMPLETE", + "stageId" : 0, + "attemptId" : 0, + "numTasks" : 10, + "numActiveTasks" : 0, + "numCompleteTasks" : 10, + "numFailedTasks" : 4, + "numKilledTasks" : 0, + "numCompletedIndices" : 10, + "executorRunTime" : 5080, + "executorCpuTime" : 1163210819, + "submissionTime" : "2018-01-18T18:33:12.658GMT", + "firstTaskLaunchedTime" : "2018-01-18T18:33:12.816GMT", + "completionTime" : "2018-01-18T18:33:15.279GMT", + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleReadBytes" : 0, + "shuffleReadRecords" : 0, + "shuffleWriteBytes" : 1461, + "shuffleWriteRecords" : 30, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "name" : "map at :27", + "details" : 
"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", + "schedulingPool" : "default", + "rddIds" : [ 1, 0 ], + "accumulatorUpdates" : [ ], + "tasks" : { + "0" : { + "taskId" : 0, + "index" : 0, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.816GMT", + "duration" : 2064, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1081, + "executorDeserializeCpuTime" : 353981050, + "executorRunTime" : 914, + "executorCpuTime" : 368865439, + "resultSize" : 1134, + "jvmGcTime" : 75, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 3662221, + "recordsWritten" : 3 + } + } + }, + "5" : { + "taskId" : 5, + "index" : 5, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:14.320GMT", + "duration" : 73, + "executorId" : "5", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 27, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + 
"diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 191901, + "recordsWritten" : 0 + } + } + }, + "10" : { + "taskId" : 10, + "index" : 1, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:15.069GMT", + "duration" : 132, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 4598966, + "executorRunTime" : 76, + "executorCpuTime" : 20826337, + "resultSize" : 1091, + "jvmGcTime" : 0, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 301705, + "recordsWritten" : 3 + } + } + }, + "1" : { + "taskId" : 1, + "index" : 1, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.832GMT", + "duration" : 1506, + "executorId" : "5", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 1332, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 33, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + 
"writeTime" : 3075188, + "recordsWritten" : 0 + } + } + }, + "6" : { + "taskId" : 6, + "index" : 6, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:14.323GMT", + "duration" : 67, + "executorId" : "4", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 51, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 183718, + "recordsWritten" : 0 + } + } + }, + "9" : { + "taskId" : 9, + "index" : 4, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.973GMT", + "duration" : 96, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 4793905, + "executorRunTime" : 48, + "executorCpuTime" : 25678331, + "resultSize" : 1091, + "jvmGcTime" : 0, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 366050, + "recordsWritten" : 3 + } + } + }, + "13" : { + "taskId" : 13, + "index" : 9, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.200GMT", + "duration" : 76, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 25, + "executorDeserializeCpuTime" : 5860574, + 
"executorRunTime" : 25, + "executorCpuTime" : 20585619, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 369513, + "recordsWritten" : 3 + } + } + }, + "2" : { + "taskId" : 2, + "index" : 2, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.832GMT", + "duration" : 1774, + "executorId" : "3", + "host" : "apiros-2.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1206, + "executorDeserializeCpuTime" : 263386625, + "executorRunTime" : 493, + "executorCpuTime" : 278399617, + "resultSize" : 1134, + "jvmGcTime" : 78, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 3322956, + "recordsWritten" : 3 + } + } + }, + "12" : { + "taskId" : 12, + "index" : 8, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.165GMT", + "duration" : 60, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 4010338, + "executorRunTime" : 34, + "executorCpuTime" : 21657558, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 319101, + "recordsWritten" : 3 + } + } + }, + "7" : { + "taskId" : 7, + "index" : 5, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.859GMT", + "duration" : 115, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 10894331, + "executorRunTime" : 84, + "executorCpuTime" : 28283110, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + 
"recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 377601, + "recordsWritten" : 3 + } + } + }, + "3" : { + "taskId" : 3, + "index" : 3, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.833GMT", + "duration" : 2027, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1282, + "executorDeserializeCpuTime" : 365807898, + "executorRunTime" : 681, + "executorCpuTime" : 349920830, + "resultSize" : 1134, + "jvmGcTime" : 102, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 3587839, + "recordsWritten" : 3 + } + } + }, + "11" : { + "taskId" : 11, + "index" : 7, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.072GMT", + "duration" : 93, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 4239884, + "executorRunTime" : 77, + "executorCpuTime" : 21689428, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 323898, + "recordsWritten" : 3 + } + } + }, + "8" : { + "taskId" : 8, + "index" : 6, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.879GMT", + "duration" : 194, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 56, + "executorDeserializeCpuTime" : 12246145, + "executorRunTime" : 54, + "executorCpuTime" : 27304550, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 311940, + 
"recordsWritten" : 3 + } + } + }, + "4" : { + "taskId" : 4, + "index" : 4, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.833GMT", + "duration" : 1522, + "executorId" : "4", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 1184, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 82, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 16858066, + "recordsWritten" : 0 + } + } + } + }, + "executorSummary" : { + "4" : { + "taskTime" : 1589, + "failedTasks" : 2, + "succeededTasks" : 0, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + }, + "5" : { + "taskTime" : 1579, + "failedTasks" : 2, + "succeededTasks" : 0, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + }, + "1" : { + "taskTime" : 2411, + "failedTasks" : 0, + "succeededTasks" : 4, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 585, + "shuffleWriteRecords" : 12, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false + }, + "2" : { + "taskTime" : 2446, + "failedTasks" : 0, + "succeededTasks" : 5, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 732, + "shuffleWriteRecords" : 15, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 
0, + "isBlacklistedForStage" : false + }, + "3" : { + "taskTime" : 1774, + "failedTasks" : 0, + "succeededTasks" : 1, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 144, + "shuffleWriteRecords" : 3, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + } + }, + "killedTasksSummary" : { } +} \ No newline at end of file diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index c925c1dd8a4d..4fecf84db65a 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + 
"endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,10 +140,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] }, { "id" : "local-1422981780767", @@ -126,9 +155,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981779720, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981788731 } ] }, { "id" : "local-1422981759269", @@ -141,8 +170,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index 942e6d8f0436..7bb8fe8fd8f9 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -19,5 +19,6 @@ "isBlacklisted" : false, "maxMemory" : 278302556, "addTime" : "2015-02-03T16:43:00.906GMT", - "executorLogs" : { } + "executorLogs" : { }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index ed33c90dd39b..dd5b1dcb7372 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -25,7 +25,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -56,7 +57,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] } ,{ "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -87,7 +89,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -118,7 +121,8 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -149,5 +153,6 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index 73519f1d9e2e..3e55d3d9d7eb 100644 --- 
a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -25,7 +25,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -56,7 +57,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -87,7 +89,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -118,7 +121,8 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -149,5 +153,6 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json index 6931fead3d2f..e87f3e78f2dc 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json @@ -19,7 +19,8 @@ "isBlacklisted" : false, "maxMemory" : 384093388, "addTime" : "2016-11-15T23:20:38.836GMT", - "executorLogs" : { } + "executorLogs" : { }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.111:64543", @@ -44,7 +45,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "2", "hostPort" : "172.22.0.111:64539", @@ -69,7 +71,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.111:64541", @@ -94,7 +97,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.111:64540", @@ -119,5 +123,6 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json index cc0b2b0022bd..79950b0dc648 100644 --- 
a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -1,46 +1,46 @@ [ { - "id" : "app-20161116163331-0000", + "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2016-11-16T22:33:29.916GMT", - "endTime" : "2016-11-16T22:33:40.587GMT", + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", "lastUpdated" : "", - "duration" : 10671, - "sparkUser" : "jose", + "duration" : 472819, + "sparkUser" : "attilapiros", "completed" : true, - "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, - "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 } ] }, { - "id" : "app-20161115172038-0000", + "id" : "app-20180109111548-0000", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2016-11-15T23:20:37.079GMT", - "endTime" : "2016-11-15T23:22:18.874GMT", + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", "lastUpdated" : "", - "duration" : 101795, - "sparkUser" : "jose", + "duration" : 535234, + "sparkUser" : "attilapiros", "completed" : true, - "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, - "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 } ] }, { - "id" : "local-1430917381534", + "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2015-05-06T13:03:00.893GMT", - "endTime" : "2015-05-06T13:03:11.398GMT", + "startTime" : "2016-11-16T22:33:29.916GMT", + "endTime" : "2016-11-16T22:33:40.587GMT", "lastUpdated" : "", - "duration" : 10505, - "sparkUser" : "irashid", + "duration" : 10671, + "sparkUser" : "jose", "completed" : true, - "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, - "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.1.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479335609916, + "endTimeEpoch" : 1479335620587 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 5af50abd8533..7d60977dcd4f 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + 
"startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,8 +140,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json index 7f896c74b5be..dfbfd8aedcc2 100644 --- a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + 
"duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,8 +39,8 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "startTimeEpoch" : 1479335609916, "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479335609916, "endTimeEpoch" : 1479335620587 } ] }, { @@ -24,8 +54,8 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "startTimeEpoch" : 1479252037079, "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479252037079, "endTimeEpoch" : 1479252138874 } ] }, { diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 31093a661663..03f886afa541 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -421,7 +421,8 @@ "shuffleWrite" : 13180, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 601d70695b17..947c89906955 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -421,7 +421,8 @@ "shuffleWrite" : 13180, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index f8e27703c0de..5c42ac1d87f4 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 2010.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 2.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 2.0, 6.0, 53.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json index a28bda16a956..e6b705989cc9 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 1034.0, 1034.0, 1034.0, 1034.0, 1034.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "gettingResultTime" : [ 0.0, 0.0, 
0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 4.0, 4.0, 6.0, 7.0, 9.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json index ede3eaed1d1d..788f28cf7b36 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 2010.0, 2065.0, 2065.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 5.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 1.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 2.0, 4.0, 6.0, 13.0, 40.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 9cdcef074618..963f010968b6 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -465,7 +465,8 @@ "shuffleWrite" : 0, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/spark-events/app-20180109111548-0000 b/core/src/test/resources/spark-events/app-20180109111548-0000 new file mode 100755 index 000000000000..50893d3001b9 --- /dev/null +++ b/core/src/test/resources/spark-events/app-20180109111548-0000 @@ -0,0 +1,59 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.3.0-SNAPSHOT"} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre","Java Version":"1.8.0_152 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.enabled":"true","spark.driver.host":"172.30.65.138","spark.eventLog.enabled":"true","spark.driver.port":"64273","spark.repl.class.uri":"spark://172.30.65.138:64273/classes","spark.jars":"","spark.repl.class.outputDir":"/private/var/folders/9g/gf583nd1765cvfgb_lsvwgp00000gp/T/spark-811c1b49-eb66-4bfb-91ae-33b45efa269d/repl-c4438f51-ee23-41ed-8e04-71496e2f40f5","spark.app.name":"Spark shell","spark.scheduler.mode":"FIFO","spark.ui.showConsoleProgress":"true","spark.blacklist.stage.maxFailedTasksPerExecutor":"1","spark.executor.id":"driver","spark.submit.deployMode":"client","spark.master":"local-cluster[2,1,1024]","spark.home":"*********(redacted)","spark.sql.catalogImplementation":"in-memory","spark.blacklist.application.maxFailedTasksPerExecutor":"10","spark.app.id":"app-20180109111548-0000"},"System Properties":{"java.io.tmpdir":"/var/folders/9g/gf583nd1765cvfgb_lsvwgp00000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered 
Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"*********(redacted)","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib","user.dir":"*********(redacted)","java.library.path":"*********(redacted)","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.152-b16","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_152-b16","java.vm.info":"mixed mode","java.ext.dirs":"*********(redacted)","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"*********(redacted)","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.12.6","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"*********(redacted)","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"*********(redacted)","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"*********(redacted)","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"*********(redacted)","user.name":"*********(redacted)","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master local-cluster[2,1,1024] --conf spark.blacklist.stage.maxFailedTasksPerExecutor=1 --conf spark.blacklist.enabled=true --conf spark.blacklist.application.maxFailedTasksPerExecutor=10 --conf spark.eventLog.enabled=true --class org.apache.spark.repl.Main --name Spark shell spark-shell","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre","java.version":"1.8.0_152","sun.io.unicode.encoding":"UnicodeBig"},"Classpath 
Entries":{"/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/api-asn1-api-1.0.0-M20.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/machinist_2.11-0.6.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/lz4-java-1.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/activation-1.1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.5.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-framework-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-jaxrs-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/py4j-0.10.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"*********(redacted)","/Users/attilapi
ros/github/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.7.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xercesImpl-2.9.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/gson-2.2.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-format-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.7.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/janino-3.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-client-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.13.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0-2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/httpcore-4.4.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly
/target/scala-2.11/jars/spark-network-common_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-memory-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.13.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-vector-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/htrace-core-3.0.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/bcprov-jdk15on-1.58.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scalap-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/netty-all-4.1.17.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hppc-0.7.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-io-2.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/httpclient-4.5.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.3.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/java-xmlbuilder-1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-net-2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/flatbuffers-1.2.0-3f79e055.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-clien
t-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/orc-mapreduce-1.4.1-nohive.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/paranamer-2.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/aircompressor-0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.7.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/orc-core-1.4.1-nohive.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jets3t-0.9.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/base64-2.3.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-lang-2.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-recipes-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.13.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/apacheds-i18n-2.0.0-M15.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/conf/":"*
********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/chill-java-0.8.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spire_2.11-0.13.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-kvstore_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/macro-compat_2.11-1.1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jaxb-api-2.2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/zstd-jni-1.3.2-2.jar":"*********(redacted)","/Users/attil
apiros/github/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/netty-3.9.9.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/api-util-1.0.0-M20.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-client-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/joda-time-2.9.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-xc-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"*********(redacted)"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"app-20180109111548-0000","Timestamp":1515492942372,"User":"attilapiros"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1515492965588,"Executor ID":"0","Executor Info":{"Host":"172.30.65.138","Total Cores":1,"Log Urls":{"stdout":"http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stdout","stderr":"http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stderr"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1515492965598,"Executor ID":"1","Executor Info":{"Host":"172.30.65.138","Total Cores":1,"Log Urls":{"stdout":"http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout","stderr":"http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.30.65.138","Port":64290},"Maximum Memory":384093388,"Timestamp":1515492965643,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.30.65.138","Port":64291},"Maximum Memory":384093388,"Timestamp":1515492965652,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1515493278122,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number 
of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0,1],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent 
IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493278152,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1515493278347,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1515493278364,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1515493278899,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1515493278918,"executorId":"0","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method 
Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"460","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"14","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"3873006","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1515493278347,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278909,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3873006,"Value":3873006,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":14,"Value":14,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":460,"Value":460,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":460,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":14,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":3873006,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":0,"Attempt":1,"Launch Time":1515493278919,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278943,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":207014,"Value":6615636,"Internal":true,"Count Failed 
Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":92,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":2144,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9364000,"Value":207843000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":698,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":8878000,"Value":208907000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":8,"Value":309,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat 
org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"16","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"126128","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1515493278899,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278926,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":126128,"Value":3999134,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":16,"Value":476,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":16,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":126128,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1515493278364,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278929,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":2409488,"Value":6408622,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":46,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":896,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":13,"Value":27,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1115,"Value":1115,"Internal":true,"Count Failed 
Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":198479000,"Value":198479000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":212,"Value":688,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":200029000,"Value":200029000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":301,"Value":301,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":301,"Executor Deserialize CPU Time":200029000,"Executor Run Time":212,"Executor CPU Time":198479000,"Result Size":1115,"JVM GC Time":13,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":2409488,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":2,"Attempt":1,"Launch Time":1515493278943,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278959,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":292381,"Value":6908017,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":2704,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":3173,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9207000,"Value":217050000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":707,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2211000,"Value":211118000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":311,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":0,"Attempt":1,"Launch Time":1515493278919,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278943,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":207014,"Value":6615636,"Internal":true,"Count Failed 
Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":92,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":2144,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9364000,"Value":207843000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":698,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":8878000,"Value":208907000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":8,"Value":309,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":8,"Executor Deserialize CPU Time":8878000,"Executor Run Time":10,"Executor CPU Time":9364000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":207014,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":3,"Attempt":0,"Launch Time":1515493278958,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278980,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":262919,"Value":7170936,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":184,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":3616,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":4202,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9635000,"Value":226685000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":716,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2586000,"Value":213704000,"Internal":true,"Count Failed 
Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":314,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":4,"Index":2,"Attempt":1,"Launch Time":1515493278943,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278959,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":292381,"Value":6908017,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":2704,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":3173,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9207000,"Value":217050000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":707,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2211000,"Value":211118000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":311,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2211000,"Executor Run Time":9,"Executor CPU Time":9207000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":292381,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":4,"Attempt":0,"Launch Time":1515493278980,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":5,"Index":3,"Attempt":0,"Launch Time":1515493278958,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278980,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":262919,"Value":7170936,"Internal":true,"Count Failed 
Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":184,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":3616,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":4202,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9635000,"Value":226685000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":716,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2586000,"Value":213704000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":314,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2586000,"Executor Run Time":9,"Executor CPU Time":9635000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":262919,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":5,"Attempt":0,"Launch Time":1515493278996,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":6,"Index":4,"Attempt":0,"Launch Time":1515493278980,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278996,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":385110,"Value":7556046,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":230,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":4528,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":5231,"Internal":true,"Count Failed 
Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9622000,"Value":236307000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":726,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2610000,"Value":216314000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":317,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2610000,"Executor Run Time":10,"Executor CPU Time":9622000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":385110,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":6,"Attempt":0,"Launch Time":1515493279011,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":5,"Attempt":0,"Launch Time":1515493278996,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279011,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":205520,"Value":7761566,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":276,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":5440,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":6260,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":8407000,"Value":244714000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":735,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2231000,"Value":218545000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":319,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2231000,"Executor Run Time":9,"Executor CPU Time":8407000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes 
Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":205520,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":7,"Attempt":0,"Launch Time":1515493279022,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":6,"Attempt":0,"Launch Time":1515493279011,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279022,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":213296,"Value":7974862,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":6352,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":7289,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6034000,"Value":250748000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":7,"Value":742,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1554000,"Value":220099000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1,"Value":320,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1,"Executor Deserialize CPU Time":1554000,"Executor Run Time":7,"Executor CPU Time":6034000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":213296,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":8,"Attempt":0,"Launch Time":1515493279034,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} 
+{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":9,"Index":7,"Attempt":0,"Launch Time":1515493279022,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279034,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":259354,"Value":8234216,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":8,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":368,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":7264,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":8318,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6335000,"Value":257083000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":7,"Value":749,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1981000,"Value":222080000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":322,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1981000,"Executor Run Time":7,"Executor CPU Time":6335000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":259354,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":9,"Attempt":0,"Launch Time":1515493279045,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":10,"Index":8,"Attempt":0,"Launch Time":1515493279034,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279046,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":243647,"Value":8477863,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":414,"Internal":true,"Count 
Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":8176,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":9347,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6157000,"Value":263240000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":6,"Value":755,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1803000,"Value":223883000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":324,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1803000,"Executor Run Time":6,"Executor CPU Time":6157000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":243647,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":11,"Index":9,"Attempt":0,"Launch Time":1515493279045,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279060,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":233652,"Value":8711515,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":460,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":9088,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":10376,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6676000,"Value":269916000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":6,"Value":761,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2017000,"Value":225900000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":327,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU 
Time":2017000,"Executor Run Time":6,"Executor CPU Time":6676000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":233652,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493278152,"Completion Time":1515493279062,"Accumulables":[{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Value":761,"Internal":true,"Count Failed Values":true},{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Value":8711515,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":27,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":10376,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":225900000,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":10,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Value":9088,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":460,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":269916000,"Internal":true,"Count Failed 
Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":1,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":327,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493279071,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":12,"Index":0,"Attempt":0,"Launch Time":1515493279077,"Executor ID":"0","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":1,"Attempt":0,"Launch Time":1515493279078,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":2,"Attempt":0,"Launch Time":1515493279152,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":13,"Index":1,"Attempt":0,"Launch Time":1515493279078,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279152,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":184,"Value":184,"Internal":true,"Count Failed 
Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":944,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":1286,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":41280000,"Value":41280000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":53,"Value":53,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":11820000,"Value":11820000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":17,"Value":17,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":17,"Executor Deserialize CPU Time":11820000,"Executor Run Time":53,"Executor CPU Time":41280000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":4,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":184,"Total Records Read":4},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":3,"Attempt":0,"Launch Time":1515493279166,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":14,"Index":2,"Attempt":0,"Launch Time":1515493279152,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279167,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":3,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":138,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":3,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":2572,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":7673000,"Value":48953000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":8,"Value":61,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1706000,"Value":13526000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":19,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1706000,"Executor Run Time":8,"Executor CPU Time":7673000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":3,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":138,"Total Records Read":3},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":4,"Attempt":0,"Launch Time":1515493279179,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":15,"Index":3,"Attempt":0,"Launch Time":1515493279166,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279180,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed 
Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":3706,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6972000,"Value":55925000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":68,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1569000,"Value":15095000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":21,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1569000,"Executor Run Time":7,"Executor CPU Time":6972000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":5,"Attempt":0,"Launch Time":1515493279190,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":4,"Attempt":0,"Launch Time":1515493279179,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279190,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":1888,"Internal":true,"Count Failed 
Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":4840,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":4905000,"Value":60830000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":5,"Value":73,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1882000,"Value":16977000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":23,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1882000,"Executor Run Time":5,"Executor CPU Time":4905000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":6,"Attempt":0,"Launch Time":1515493279193,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":0,"Attempt":0,"Launch Time":1515493279077,"Executor ID":"0","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279194,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":3,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":23,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":138,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":3,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":6126,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":56742000,"Value":117572000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":89,"Value":162,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":12625000,"Value":29602000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":18,"Value":41,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":18,"Executor Deserialize CPU Time":12625000,"Executor Run Time":89,"Executor CPU Time":56742000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":3,"Local Blocks Fetched":0,"Fetch Wait Time":23,"Remote Bytes Read":138,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":3},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":7,"Attempt":0,"Launch Time":1515493279202,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":17,"Index":5,"Attempt":0,"Launch Time":1515493279190,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279203,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":7260,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6476000,"Value":124048000,"Internal":true,"Count Failed 
Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":169,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1890000,"Value":31492000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":43,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1890000,"Executor Run Time":7,"Executor CPU Time":6476000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":8,"Attempt":0,"Launch Time":1515493279215,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":19,"Index":7,"Attempt":0,"Launch Time":1515493279202,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279216,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":8394,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6927000,"Value":130975000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":176,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2038000,"Value":33530000,"Internal":true,"Count Failed 
Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":45,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2038000,"Executor Run Time":7,"Executor CPU Time":6927000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":9,"Attempt":0,"Launch Time":1515493279218,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":6,"Attempt":0,"Launch Time":1515493279193,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279218,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":9528,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":11214000,"Value":142189000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":16,"Value":192,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2697000,"Value":36227000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":49,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":2697000,"Executor Run 
Time":16,"Executor CPU Time":11214000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":20,"Index":8,"Attempt":0,"Launch Time":1515493279215,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279226,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":10662,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":4905000,"Value":147094000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":5,"Value":197,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1700000,"Value":37927000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":51,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1700000,"Executor Run Time":5,"Executor CPU Time":4905000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} 
+{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":21,"Index":9,"Attempt":0,"Launch Time":1515493279218,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279232,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":11796,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":7850000,"Value":154944000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":8,"Value":205,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2186000,"Value":40113000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":54,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2186000,"Executor Run Time":8,"Executor CPU Time":7850000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent 
IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493279071,"Completion Time":1515493279232,"Accumulables":[{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":23,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Value":40113000,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Value":11796,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":138,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":322,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Value":54,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Value":2832,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":7,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Value":154944000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Value":205,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":3,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Value":10,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1515493279237,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerApplicationEnd","Timestamp":1515493477606} diff --git a/core/src/test/resources/spark-events/application_1516285256255_0012 b/core/src/test/resources/spark-events/application_1516285256255_0012 new file mode 100755 index 000000000000..3e1736c3fe22 --- /dev/null +++ b/core/src/test/resources/spark-events/application_1516285256255_0012 @@ -0,0 +1,71 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.3.0-SNAPSHOT"} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre","Java Version":"1.8.0_161 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark 
Properties":{"spark.blacklist.enabled":"true","spark.driver.host":"apiros-1.gce.test.com","spark.eventLog.enabled":"true","spark.driver.port":"33058","spark.repl.class.uri":"spark://apiros-1.gce.test.com:33058/classes","spark.jars":"","spark.repl.class.outputDir":"/tmp/spark-6781fb17-e07a-4b32-848b-9936c2e88b33/repl-c0fd7008-04be-471e-a173-6ad3e62d53d7","spark.app.name":"Spark shell","spark.blacklist.stage.maxFailedExecutorsPerNode":"1","spark.scheduler.mode":"FIFO","spark.executor.instances":"8","spark.ui.showConsoleProgress":"true","spark.blacklist.stage.maxFailedTasksPerExecutor":"1","spark.executor.id":"driver","spark.submit.deployMode":"client","spark.master":"yarn","spark.ui.filters":"org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter","spark.executor.memory":"2G","spark.home":"/github/spark","spark.sql.catalogImplementation":"hive","spark.driver.appUIAddress":"http://apiros-1.gce.test.com:4040","spark.blacklist.application.maxFailedTasksPerExecutor":"10","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_HOSTS":"apiros-1.gce.test.com","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_URI_BASES":"http://apiros-1.gce.test.com:8088/proxy/application_1516285256255_0012","spark.app.id":"application_1516285256255_0012"},"System Properties":{"java.io.tmpdir":"/tmp","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"*********(redacted)","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","sun.arch.data.model":"64","sun.boot.library.path":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/amd64","user.dir":"*********(redacted)","java.library.path":"/usr/java/packages/lib/amd64:/usr/lib64:/lib64:/lib:/usr/lib","sun.cpu.isalist":"","os.arch":"amd64","java.vm.version":"25.161-b14","java.endorsed.dirs":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/endorsed","java.runtime.version":"1.8.0_161-b14","java.vm.info":"mixed mode","java.ext.dirs":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/ext:/usr/java/packages/lib/ext","java.runtime.name":"OpenJDK Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/resources.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/rt.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/sunrsasign.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jsse.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jce.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/charsets.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jfr.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/classes","file.encoding":"UTF-8","user.timezone":"*********(redacted)","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"3.10.0-693.5.2.el7.x86_64","sun.os.patch.level":"unknown","java.vm.specification.vendor":"Oracle 
Corporation","user.country":"*********(redacted)","sun.jnu.encoding":"UTF-8","user.language":"*********(redacted)","java.vendor.url":"*********(redacted)","java.awt.printerjob":"sun.print.PSPrinterJob","java.awt.graphicsenv":"sun.awt.X11GraphicsEnvironment","awt.toolkit":"sun.awt.X11.XToolkit","os.name":"Linux","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"*********(redacted)","user.name":"*********(redacted)","java.vm.name":"OpenJDK 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master yarn --deploy-mode client --conf spark.blacklist.stage.maxFailedTasksPerExecutor=1 --conf spark.blacklist.enabled=true --conf spark.blacklist.application.maxFailedTasksPerExecutor=10 --conf spark.blacklist.stage.maxFailedExecutorsPerNode=1 --conf spark.eventLog.enabled=true --class org.apache.spark.repl.Main --name Spark shell --executor-memory 2G --num-executors 8 spark-shell","java.home":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre","java.version":"1.8.0_161","sun.io.unicode.encoding":"UnicodeLittle"},"Classpath Entries":{"/github/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-vector-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-io-2.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-hive_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0-2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apache-log4j-extras-1.2.17.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hive-metastore-1.2.1.spark2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"System Classpath","/github/spark/conf/":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aircompressor-0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-recipes-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-format-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/libthrift-0.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-lang-2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.3.0-SNAPSHOT.jar":"System 
Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stringtemplate-3.2.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-net-2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/orc-core-1.4.1-nohive.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spire_2.11-0.13.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-memory-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scalap-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/JavaEWAH-0.3.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/bcprov-jdk15on-1.58.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javolution-5.5.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/libfb303-0.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jodd-core-3.5.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/janino-3.0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/snappy-0.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guice-3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/java-xmlbuilder-1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/ST4-4.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-core-3.2.10.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guice-servlet-3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"System 
Classpath","/github/spark/assembly/target/scala-2.11/jars/hive-exec-1.2.1.spark2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jetty-6.1.26.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/base64-2.3.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/etc/hadoop/conf/":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-yarn_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.13.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-client-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-jaxrs-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apacheds-i18n-2.0.0-M15.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/lz4-java-1.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.13.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-framework-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-client-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/netty-3.9.9.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-avatica-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/machinist_2.11-0.6.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jaxb-api-2.2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/api-asn1-api-1.0.0-M20.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-linq4j-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/orc-mapreduce-1.4.1-nohive.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-common-2.6.0.jar":"System 
Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xercesImpl-2.9.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hppc-0.7.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.3.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/activation-1.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/py4j-0.10.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-bundle-1.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr-runtime-3.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/eigenbase-properties-1.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/paranamer-2.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jta-1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/derby-10.12.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.6.0.jar":"System 
Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-logging-1.1.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-pool-1.5.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/htrace-core-3.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/httpclient-4.5.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/zstd-jni-1.3.2-2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-web-proxy-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-kvstore_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.5.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/chill-java-0.8.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr-2.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/api-util-1.0.0-M20.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jdo-api-3.0.1.jar":"System 
Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.7.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-dbcp-1.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/netty-all-4.1.17.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/gson-2.2.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-core-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/macro-compat_2.11-1.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/flatbuffers-1.2.0-3f79e055.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.13.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jets3t-0.9.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-xc-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/bonecp-0.8.0.RELEASE.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/joda-time-2.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/httpcore-4.4.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.8.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"application_1516285256255_0012","Timestamp":1516300235119,"User":"attilapiros"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252095,"Executor ID":"2","Executor Info":{"Host":"apiros-3.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096","stderr":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"apiros-3.gce.test.com","Port":38670},"Maximum Memory":956615884,"Timestamp":1516300252260,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} 
+{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252715,"Executor ID":"3","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252918,"Executor ID":"1","Executor Info":{"Host":"apiros-3.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096","stderr":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"apiros-2.gce.test.com","Port":38641},"Maximum Memory":956615884,"Timestamp":1516300252959,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"apiros-3.gce.test.com","Port":34970},"Maximum Memory":956615884,"Timestamp":1516300252988,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300253542,"Executor ID":"4","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"4","Host":"apiros-2.gce.test.com","Port":33229},"Maximum Memory":956615884,"Timestamp":1516300253653,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300254323,"Executor ID":"5","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"5","Host":"apiros-2.gce.test.com","Port":45147},"Maximum Memory":956615884,"Timestamp":1516300254385,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1516300392631,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent 
IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0,1],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent 
IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300392658,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1516300392816,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1516300394320,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1516300394323,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1516300394348,"executorId":"5","taskFailures":1,"stageId":0,"stageAttemptId":0} 
+{"Event":"org.apache.spark.scheduler.SparkListenerNodeBlacklistedForStage","time":1516300394348,"hostId":"apiros-2.gce.test.com","executorFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1516300394356,"executorId":"4","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"1332","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"33","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"3075188","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1516300392832,"Executor 
ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394338,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3075188,"Value":3075188,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":33,"Value":33,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1332,"Value":1332,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":1332,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":33,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":3075188,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat 
org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"1184","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"82","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"16858066","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394355,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":16858066,"Value":19933254,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":82,"Value":115,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1184,"Value":2516,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":1184,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":82,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":16858066,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line 
Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"51","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"183718","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1516300394323,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394390,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":183718,"Value":20116972,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":51,"Value":2567,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":51,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":183718,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line 
Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"27","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"191901","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1516300394320,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394393,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":191901,"Value":20308873,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":27,"Value":2594,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":27,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":191901,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task 
Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394606,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3322956,"Value":23631829,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":144,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":1080,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":78,"Value":193,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":1134,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":278399617,"Value":278399617,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":493,"Value":3087,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":263386625,"Value":263386625,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1206,"Value":1206,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1206,"Executor Deserialize CPU Time":263386625,"Executor Run Time":493,"Executor CPU Time":278399617,"Result Size":1134,"JVM GC Time":78,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":3322956,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":5,"Attempt":1,"Launch Time":1516300394859,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394860,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3587839,"Value":27219668,"Internal":true,"Count Failed 
Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":291,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":2160,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":102,"Value":295,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":2268,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":349920830,"Value":628320447,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":681,"Value":3768,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":365807898,"Value":629194523,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1282,"Value":2488,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1282,"Executor Deserialize CPU Time":365807898,"Executor Run Time":681,"Executor CPU Time":349920830,"Result Size":1134,"JVM GC Time":102,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":3587839,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":6,"Attempt":1,"Launch Time":1516300394879,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1516300392816,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394880,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3662221,"Value":30881889,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":435,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":3240,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count 
Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":75,"Value":370,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":3402,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":368865439,"Value":997185886,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":914,"Value":4682,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":353981050,"Value":983175573,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1081,"Value":3569,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1081,"Executor Deserialize CPU Time":353981050,"Executor Run Time":914,"Executor CPU Time":368865439,"Result Size":1134,"JVM GC Time":75,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":3662221,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":4,"Attempt":1,"Launch Time":1516300394973,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":5,"Attempt":1,"Launch Time":1516300394859,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394974,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":377601,"Value":31259490,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":582,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":4320,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":4450,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":28283110,"Value":1025468996,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":84,"Value":4766,"Internal":true,"Count Failed 
Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":10894331,"Value":994069904,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":11,"Value":3580,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":11,"Executor Deserialize CPU Time":10894331,"Executor Run Time":84,"Executor CPU Time":28283110,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":377601,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":1,"Attempt":1,"Launch Time":1516300395069,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":9,"Index":4,"Attempt":1,"Launch Time":1516300394973,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395069,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":366050,"Value":31625540,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":15,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":729,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":5400,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1091,"Value":5541,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":25678331,"Value":1051147327,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":48,"Value":4814,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4793905,"Value":998863809,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":3585,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4793905,"Executor Run Time":48,"Executor CPU Time":25678331,"Result Size":1091,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait 
Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":366050,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":7,"Attempt":0,"Launch Time":1516300395072,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":6,"Attempt":1,"Launch Time":1516300394879,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395073,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":311940,"Value":31937480,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":18,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":876,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":6480,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":6589,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":27304550,"Value":1078451877,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":54,"Value":4868,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":12246145,"Value":1011109954,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":56,"Value":3641,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":56,"Executor Deserialize CPU Time":12246145,"Executor Run Time":54,"Executor CPU Time":27304550,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":311940,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":12,"Index":8,"Attempt":0,"Launch Time":1516300395165,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt 
ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":11,"Index":7,"Attempt":0,"Launch Time":1516300395072,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395165,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":323898,"Value":32261378,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":21,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1023,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":7560,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":7637,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":21689428,"Value":1100141305,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":77,"Value":4945,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4239884,"Value":1015349838,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":3645,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":4239884,"Executor Run Time":77,"Executor CPU Time":21689428,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":323898,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":9,"Attempt":0,"Launch Time":1516300395200,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":10,"Index":1,"Attempt":1,"Launch Time":1516300395069,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395201,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":301705,"Value":32563083,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":24,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":1167,"Internal":true,"Count Failed 
Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":8640,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1091,"Value":8728,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":20826337,"Value":1120967642,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":76,"Value":5021,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4598966,"Value":1019948804,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":3650,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4598966,"Executor Run Time":76,"Executor CPU Time":20826337,"Result Size":1091,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":301705,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":8,"Attempt":0,"Launch Time":1516300395165,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395225,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":319101,"Value":32882184,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":27,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1314,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":9720,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":9776,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":21657558,"Value":1142625200,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":34,"Value":5055,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4010338,"Value":1023959142,"Internal":true,"Count Failed 
Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":3654,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":4010338,"Executor Run Time":34,"Executor CPU Time":21657558,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":319101,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":13,"Index":9,"Attempt":0,"Launch Time":1516300395200,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395276,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":369513,"Value":33251697,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":30,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1461,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":10800,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":10824,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":20585619,"Value":1163210819,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":25,"Value":5080,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":5860574,"Value":1029819716,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":25,"Value":3679,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":25,"Executor Deserialize CPU Time":5860574,"Executor Run Time":25,"Executor CPU Time":20585619,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":369513,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at 
:27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300392658,"Completion Time":1516300395279,"Accumulables":[{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Value":5080,"Internal":true,"Count Failed Values":true},{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Value":33251697,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":370,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":10824,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":1029819716,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":30,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Value":10800,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":1461,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":1163210819,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":5,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":3679,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent 
IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300395292,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":0,"Attempt":0,"Launch Time":1516300395302,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1516300395303,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":3,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":4,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":5,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":6,"Attempt":0,"Launch Time":1516300395525,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":17,"Index":4,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395525,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":1134,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":52455999,"Value":52455999,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":95,"Value":95,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":23136577,"Value":23136577,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":82,"Value":82,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":82,"Executor Deserialize CPU Time":23136577,"Executor Run Time":95,"Executor CPU Time":52455999,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":7,"Attempt":0,"Launch Time":1516300395575,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":19,"Index":6,"Attempt":0,"Launch Time":1516300395525,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395576,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":2268,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":13617615,"Value":66073614,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":29,"Value":124,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3469612,"Value":26606189,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":86,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3469612,"Executor Run Time":29,"Executor CPU Time":13617615,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":8,"Attempt":0,"Launch Time":1516300395581,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":5,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395581,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":3402,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":55540208,"Value":121613822,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":179,"Value":303,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":22400065,"Value":49006254,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":78,"Value":164,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":78,"Executor Deserialize CPU Time":22400065,"Executor Run Time":179,"Executor CPU Time":55540208,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":22,"Index":9,"Attempt":0,"Launch Time":1516300395593,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":3,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395593,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":4536,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":52311573,"Value":173925395,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":153,"Value":456,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":20519033,"Value":69525287,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":67,"Value":231,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":67,"Executor Deserialize CPU Time":20519033,"Executor Run Time":153,"Executor CPU Time":52311573,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":20,"Index":7,"Attempt":0,"Launch Time":1516300395575,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395660,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":5670,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":11294260,"Value":185219655,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":33,"Value":489,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3570887,"Value":73096174,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":235,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3570887,"Executor Run Time":33,"Executor CPU Time":11294260,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":22,"Index":9,"Attempt":0,"Launch Time":1516300395593,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395669,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":6804,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":12983732,"Value":198203387,"Internal":true,"Count Failed 
Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":44,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3518757,"Value":76614931,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":239,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3518757,"Executor Run Time":44,"Executor CPU Time":12983732,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":21,"Index":8,"Attempt":0,"Launch Time":1516300395581,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395674,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":7938,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":14706240,"Value":212909627,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":64,"Value":597,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":7698059,"Value":84312990,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":21,"Value":260,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":21,"Executor Deserialize CPU Time":7698059,"Executor Run Time":64,"Executor CPU Time":14706240,"Result Size":1134,"JVM GC 
Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":23,"Index":2,"Attempt":0,"Launch Time":1516300395686,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":14,"Index":0,"Attempt":0,"Launch Time":1516300395302,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395687,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":52,"Value":52,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":195,"Value":195,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":292,"Value":292,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":944,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":9224,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":91696783,"Value":304606410,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":221,"Value":818,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":24063461,"Value":108376451,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":150,"Value":410,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":150,"Executor Deserialize CPU Time":24063461,"Executor Run Time":221,"Executor CPU Time":91696783,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":4,"Fetch Wait Time":52,"Remote Bytes Read":292,"Remote Bytes Read To 
Disk":0,"Local Bytes Read":195,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1516300395303,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395687,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":20,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":107,"Value":159,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":244,"Value":439,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":243,"Value":535,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":5,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":5,"Value":11,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":10510,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":91683507,"Value":396289917,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":289,"Value":1107,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":22106726,"Value":130483177,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":79,"Value":489,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":79,"Executor Deserialize CPU Time":22106726,"Executor Run Time":289,"Executor CPU Time":91683507,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":5,"Local Blocks Fetched":5,"Fetch Wait Time":107,"Remote Bytes Read":243,"Remote Bytes Read To Disk":0,"Local Bytes Read":244,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":23,"Index":2,"Attempt":0,"Launch Time":1516300395686,"Executor 
ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395728,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":30,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":159,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":195,"Value":634,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":292,"Value":827,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":13,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":17,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":11796,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":17607810,"Value":413897727,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":33,"Value":1140,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2897647,"Value":133380824,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":491,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2897647,"Executor Run Time":33,"Executor CPU Time":17607810,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":4,"Fetch Wait Time":0,"Remote Bytes Read":292,"Remote Bytes Read To Disk":0,"Local Bytes Read":195,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent 
IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300395292,"Completion Time":1516300395728,"Accumulables":[{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":159,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Value":133380824,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Value":11796,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":827,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":634,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Value":491,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Value":2832,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":13,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Value":413897727,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Value":1140,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":17,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Value":30,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1516300395734,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerApplicationEnd","Timestamp":1516300707938} diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala index 91355f736290..a5bdc9579072 100644 --- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -103,8 +103,11 @@ class DebugFilesystem extends LocalFileSystem { override def markSupported(): Boolean = wrapped.markSupported() override def close(): Unit = { - wrapped.close() - removeOpenStream(wrapped) + try { + wrapped.close() + } finally { + removeOpenStream(wrapped) + } } override def read(): Int = wrapped.read() diff --git 
a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index e09d5f59817b..28ea0c6f0bdb 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -160,11 +160,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) assert(cachedData.count === 1000) - assert(sc.getExecutorStorageStatus.map(_.rddBlocksById(cachedData.id).size).sum === - storageLevel.replication * data.getNumPartitions) - assert(cachedData.count === 1000) - assert(cachedData.count === 1000) - + assert(sc.getRDDStorageInfo.filter(_.id == cachedData.id).map(_.numCachedPartitions).sum === + data.getNumPartitions) // Get all the locations of the first partition and try to fetch the partitions // from those locations. val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 962945e5b6bb..896cd2e80aae 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -51,7 +51,7 @@ class DriverSuite extends SparkFunSuite with TimeLimits { */ object DriverWithoutCleanup { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf) sc.parallelize(1 to 100, 4).count() diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index a0cae5a9e011..3cfb0a9feb32 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark import scala.collection.mutable +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.{mock, never, verify, when} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics @@ -26,6 +28,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ExternalClusterManager import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.local.LocalSchedulerBackend +import org.apache.spark.storage.BlockManagerMaster import org.apache.spark.util.ManualClock /** @@ -142,6 +145,39 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) } + def testAllocationRatio(cores: Int, divisor: Double, expected: Int): Unit = { + val conf = new SparkConf() + .setMaster("myDummyLocalExternalClusterManager") + .setAppName("test-executor-allocation-manager") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .set("spark.dynamicAllocation.maxExecutors", "15") + .set("spark.dynamicAllocation.minExecutors", "3") + .set("spark.dynamicAllocation.executorAllocationRatio", divisor.toString) + .set("spark.executor.cores", cores.toString) + val sc = new SparkContext(conf) + contexts += sc + var manager = sc.executorAllocationManager.get + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 20))) + for (i <- 0 to 5) { + addExecutors(manager) + } + assert(numExecutorsTarget(manager) === expected) + 
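For context on the expected counts asserted here and exercised by the calls just below: with dynamic allocation, the executor target for a stage is, roughly, ceil(pendingTasks * spark.dynamicAllocation.executorAllocationRatio / coresPerExecutor), clamped to the configured min/max. The formula as written is an assumption of this sketch rather than a quote from the implementation, but the arithmetic matches the 20-task stage and the min=3/max=15 bounds used in the test:

// A rough, self-contained model of the target-executor arithmetic exercised by this test.
// Assumed form: ceil(tasks * ratio / coresPerExecutor), clamped to [min, max].
object AllocationRatioSketch {
  def targetExecutors(pendingTasks: Int, ratio: Double, coresPerExecutor: Int,
      minExecutors: Int, maxExecutors: Int): Int = {
    val needed = math.ceil(pendingTasks * ratio / coresPerExecutor).toInt
    math.min(maxExecutors, math.max(minExecutors, needed))
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the expectations below: a 20-task stage, minExecutors = 3, maxExecutors = 15.
    assert(targetExecutors(20, 0.5, 1, 3, 15) == 10)
    assert(targetExecutors(20, 1.0 / 3.0, 1, 3, 15) == 7)
    assert(targetExecutors(20, 1.0 / 3.0, 2, 3, 15) == 4)
    assert(targetExecutors(20, 0.385, 1, 3, 15) == 8)
    assert(targetExecutors(20, 1.0, 1, 3, 15) == 15)       // would be 20, capped by max
    assert(targetExecutors(20, 1.0 / 3.0, 4, 3, 15) == 3)  // would be 2, raised to min
  }
}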
sc.stop() + } + + test("executionAllocationRatio is correctly handled") { + testAllocationRatio(1, 0.5, 10) + testAllocationRatio(1, 1.0/3.0, 7) + testAllocationRatio(2, 1.0/3.0, 4) + testAllocationRatio(1, 0.385, 8) + + // max/min executors capping + testAllocationRatio(1, 1.0, 15) // should be 20 but capped by max + testAllocationRatio(4, 1.0/3.0, 3) // should be 2 but elevated by min + } + + test("add executors capped by num pending tasks") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get @@ -1050,6 +1086,66 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager) === Map.empty) } + test("SPARK-23365 Don't update target num executors when killing idle executors") { + val minExecutors = 1 + val initialExecutors = 1 + val maxExecutors = 2 + val conf = new SparkConf() + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.shuffle.service.enabled", "true") + .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) + .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) + .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString) + .set("spark.dynamicAllocation.schedulerBacklogTimeout", "1000ms") + .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", "1000ms") + .set("spark.dynamicAllocation.executorIdleTimeout", s"3000ms") + val mockAllocationClient = mock(classOf[ExecutorAllocationClient]) + val mockBMM = mock(classOf[BlockManagerMaster]) + val manager = new ExecutorAllocationManager( + mockAllocationClient, mock(classOf[LiveListenerBus]), conf, mockBMM) + val clock = new ManualClock() + manager.setClock(clock) + + when(mockAllocationClient.requestTotalExecutors(meq(2), any(), any())).thenReturn(true) + // test setup -- job with 2 tasks, scale up to two executors + assert(numExecutorsTarget(manager) === 1) + manager.listener.onExecutorAdded(SparkListenerExecutorAdded( + clock.getTimeMillis(), "executor-1", new ExecutorInfo("host1", 1, Map.empty))) + manager.listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 2))) + clock.advance(1000) + manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis()) + assert(numExecutorsTarget(manager) === 2) + val taskInfo0 = createTaskInfo(0, 0, "executor-1") + manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo0)) + manager.listener.onExecutorAdded(SparkListenerExecutorAdded( + clock.getTimeMillis(), "executor-2", new ExecutorInfo("host1", 1, Map.empty))) + val taskInfo1 = createTaskInfo(1, 1, "executor-2") + manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo1)) + assert(numExecutorsTarget(manager) === 2) + + // have one task finish -- we should adjust the target number of executors down + // but we should *not* kill any executors yet + manager.listener.onTaskEnd(SparkListenerTaskEnd(0, 0, null, Success, taskInfo0, null)) + assert(maxNumExecutorsNeeded(manager) === 1) + assert(numExecutorsTarget(manager) === 2) + clock.advance(1000) + manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis()) + assert(numExecutorsTarget(manager) === 1) + verify(mockAllocationClient, never).killExecutors(any(), any(), any(), any()) + + // now we cross the idle timeout for executor-1, so we kill it. 
the really important + // thing here is that we do *not* ask the executor allocation client to adjust the target + // number of executors down + when(mockAllocationClient.killExecutors(Seq("executor-1"), false, false, false)) + .thenReturn(Seq("executor-1")) + clock.advance(3000) + schedule(manager) + assert(maxNumExecutorsNeeded(manager) === 1) + assert(numExecutorsTarget(manager) === 1) + // here's the important verify -- we did kill the executors, but did not adjust the target count + verify(mockAllocationClient).killExecutors(Seq("executor-1"), false, false, false) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, @@ -1268,7 +1364,8 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend override def killExecutors( executorIds: Seq[String], - replace: Boolean, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean): Seq[String] = executorIds override def start(): Unit = sb.start() diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index e9539dc73f6f..a441b9c8ab97 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -23,6 +23,7 @@ import java.util.zip.GZIPOutputStream import scala.io.Source +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec @@ -32,7 +33,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInp import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.apache.spark.internal.config._ -import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} +import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -244,7 +245,10 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { for (i <- 0 until testOutputCopies) { // Shift values by i so that they're different in the output val alteredOutput = testOutput.map(b => (b + i).toByte) - channel.write(ByteBuffer.wrap(alteredOutput)) + val buffer = ByteBuffer.wrap(alteredOutput) + while (buffer.hasRemaining) { + channel.write(buffer) + } } channel.close() file.close() @@ -593,4 +597,70 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { actualPartitionNum = 5, expectedPartitionNum = 2) } + + test("spark.files.ignoreMissingFiles should work both HadoopRDD and NewHadoopRDD") { + // "file not found" can happen both when getPartitions or compute in HadoopRDD/NewHadoopRDD, + // We test both cases here. 
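The test introduced here exercises spark.files.ignoreMissingFiles for both the old and new Hadoop RDD APIs: with the flag off (the default), reading a deleted path fails the job, while with it on the missing input is skipped. A minimal usage sketch, assuming a local master; the input path below is hypothetical:

import org.apache.spark.{SparkConf, SparkContext}

object IgnoreMissingFilesSketch {
  def main(args: Array[String]): Unit = {
    // With spark.files.ignoreMissingFiles=true, files that disappear before (or between)
    // getPartitions and compute are skipped instead of failing the job.
    val conf = new SparkConf()
      .setMaster("local[1]")
      .setAppName("ignore-missing-files-sketch")
      .set("spark.files.ignoreMissingFiles", "true")
    val sc = new SparkContext(conf)
    try {
      val maybeMissing = "/tmp/path/that/may/have/been/deleted"  // hypothetical path
      println(sc.textFile(maybeMissing).collect().length)        // 0 if the path is gone
    } finally {
      sc.stop()
    }
  }
}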
+ + val deletedPath = new Path(tempDir.getAbsolutePath, "test-data-1") + val fs = deletedPath.getFileSystem(new Configuration()) + fs.delete(deletedPath, true) + intercept[FileNotFoundException](fs.open(deletedPath)) + + def collectRDDAndDeleteFileBeforeCompute(newApi: Boolean): Array[_] = { + val dataPath = new Path(tempDir.getAbsolutePath, "test-data-2") + val writer = new OutputStreamWriter(new FileOutputStream(new File(dataPath.toString))) + writer.write("hello\n") + writer.write("world\n") + writer.close() + val rdd = if (newApi) { + sc.newAPIHadoopFile(dataPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]) + } else { + sc.textFile(dataPath.toString) + } + rdd.partitions + fs.delete(dataPath, true) + // Exception happens when initialize record reader in HadoopRDD/NewHadoopRDD.compute + // because partitions' info already cached. + rdd.collect() + } + + // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=false by default. + sc = new SparkContext("local", "test") + intercept[org.apache.hadoop.mapred.InvalidInputException] { + // Exception happens when HadoopRDD.getPartitions + sc.textFile(deletedPath.toString).collect() + } + + var e = intercept[SparkException] { + collectRDDAndDeleteFileBeforeCompute(false) + } + assert(e.getCause.isInstanceOf[java.io.FileNotFoundException]) + + intercept[org.apache.hadoop.mapreduce.lib.input.InvalidInputException] { + // Exception happens when NewHadoopRDD.getPartitions + sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]).collect + } + + e = intercept[SparkException] { + collectRDDAndDeleteFileBeforeCompute(true) + } + assert(e.getCause.isInstanceOf[java.io.FileNotFoundException]) + + sc.stop() + + // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=true. + val conf = new SparkConf().set(IGNORE_MISSING_FILES, true) + sc = new SparkContext("local", "test", conf) + assert(sc.textFile(deletedPath.toString).collect().isEmpty) + + assert(collectRDDAndDeleteFileBeforeCompute(false).isEmpty) + + assert(sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]).collect().isEmpty) + + assert(collectRDDAndDeleteFileBeforeCompute(true).isEmpty) + } } diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 8d7be77f51fe..62824a5bec9d 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -135,7 +135,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt // ID that's < 2 * numPartitions belongs to the first attempt of this stage. 
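The fix in the hunk that follows widens the multiplication to Long (numPartitions * 2L) so the right-hand side of the comparison against the Long task attempt ID cannot overflow as an Int first. A minimal, self-contained illustration of the wraparound it guards against, using an extreme value purely for demonstration:

object IntOverflowSketch {
  def main(args: Array[String]): Unit = {
    val numPartitions = Int.MaxValue   // extreme value chosen only to show the effect
    println(numPartitions * 2)         // -2: Int arithmetic wraps around
    println(numPartitions * 2L)        // 4294967294: widened to Long before multiplying
  }
}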
val taskContext = TaskContext.get() - val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2 + val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2L if (isFirstStageAttempt) { throw new FetchFailedException( SparkEnv.get.blockManager.blockManagerId, diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 8a77aea75a99..61da4138896c 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future @@ -26,7 +27,7 @@ import scala.concurrent.duration._ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.ThreadUtils /** @@ -40,6 +41,10 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft override def afterEach() { try { resetSparkContext() + JobCancellationSuite.taskStartedSemaphore.drainPermits() + JobCancellationSuite.taskCancelledSemaphore.drainPermits() + JobCancellationSuite.twoJobsSharingStageSemaphore.drainPermits() + JobCancellationSuite.executionOfInterruptibleCounter.set(0) } finally { super.afterEach() } @@ -320,6 +325,67 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft f2.get() } + test("interruptible iterator of shuffle reader") { + // In this test case, we create a Spark job of two stages. The second stage is cancelled during + // execution and a counter is used to make sure that the corresponding tasks are indeed + // cancelled. + import JobCancellationSuite._ + sc = new SparkContext("local[2]", "test interruptible iterator") + + // Increase the number of elements to be proceeded to avoid this test being flaky. + val numElements = 10000 + val taskCompletedSem = new Semaphore(0) + + sc.addSparkListener(new SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + // release taskCancelledSemaphore when cancelTasks event has been posted + if (stageCompleted.stageInfo.stageId == 1) { + taskCancelledSemaphore.release(numElements) + } + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.stageId == 1) { // make sure tasks are completed + taskCompletedSem.release() + } + } + }) + + // Explicitly disable interrupt task thread on cancelling tasks, so the task thread can only be + // interrupted by `InterruptibleIterator`. + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + + val f = sc.parallelize(1 to numElements).map { i => (i, i) } + .repartitionAndSortWithinPartitions(new HashPartitioner(1)) + .mapPartitions { iter => + taskStartedSemaphore.release() + iter + }.foreachAsync { x => + // Block this code from being executed, until the job get cancelled. In this case, if the + // source iterator is interruptible, the max number of increment should be under + // `numElements`. + taskCancelledSemaphore.acquire() + executionOfInterruptibleCounter.getAndIncrement() + } + + taskStartedSemaphore.acquire() + // Job is cancelled when: + // 1. 
task in reduce stage has been started, guaranteed by previous line. + // 2. task in reduce stage is blocked as taskCancelledSemaphore is not released until + // JobCancelled event is posted. + // After job being cancelled, task in reduce stage will be cancelled asynchronously, thus + // partial of the inputs should not get processed (It's very unlikely that Spark can process + // 10000 elements between JobCancelled is posted and task is really killed). + f.cancel() + + val e = intercept[SparkException](f.get()).getCause + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + + // Make sure tasks are indeed completed. + taskCompletedSem.acquire() + assert(executionOfInterruptibleCounter.get() < numElements) + } + def testCount() { // Cancel before launching any tasks { @@ -381,7 +447,9 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft object JobCancellationSuite { + // To avoid any headaches, reset these global variables in the companion class's afterEach block val taskStartedSemaphore = new Semaphore(0) val taskCancelledSemaphore = new Semaphore(0) val twoJobsSharingStageSemaphore = new Semaphore(0) + val executionOfInterruptibleCounter = new AtomicInteger(0) } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 50b8ea754d8d..21f481d47724 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -147,7 +147,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) @@ -298,4 +298,33 @@ class MapOutputTrackerSuite extends SparkFunSuite { } } + test("zero-sized blocks should be excluded when getMapSizesByExecutorId") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 2) + + val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L)) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(size0, size1000, size0, size10000))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(size10000, size0, size1000, size0))) + assert(tracker.containsShuffle(10)) + assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + Seq( + (BlockManagerId("a", "hostA", 1000), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), + (BlockManagerId("b", "hostB", 1000), + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + ) + ) + + tracker.unregisterShuffle(10) + tracker.stop() + rpcEnv.shutdown() + } + } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 155ca17db726..9206b5debf4f 100644 --- 
a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -262,14 +262,11 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva test("defaultPartitioner") { val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) - val rdd2 = sc - .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + val rdd2 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) .partitionBy(new HashPartitioner(10)) - val rdd3 = sc - .parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) + val rdd3 = sc.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) .partitionBy(new HashPartitioner(100)) - val rdd4 = sc - .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + val rdd4 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) .partitionBy(new HashPartitioner(9)) val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11) @@ -284,7 +281,42 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva assert(partitioner3.numPartitions == rdd3.getNumPartitions) assert(partitioner4.numPartitions == rdd3.getNumPartitions) assert(partitioner5.numPartitions == rdd4.getNumPartitions) + } + test("defaultPartitioner when defaultParallelism is set") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) + val rdd2 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(10)) + val rdd3 = sc.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) + .partitionBy(new HashPartitioner(100)) + val rdd4 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(9)) + val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11) + val rdd6 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(3)) + + val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2) + val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3) + val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1) + val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3) + val partitioner5 = Partitioner.defaultPartitioner(rdd4, rdd5) + val partitioner6 = Partitioner.defaultPartitioner(rdd5, rdd5) + val partitioner7 = Partitioner.defaultPartitioner(rdd1, rdd6) + + assert(partitioner1.numPartitions == rdd2.getNumPartitions) + assert(partitioner2.numPartitions == rdd3.getNumPartitions) + assert(partitioner3.numPartitions == rdd3.getNumPartitions) + assert(partitioner4.numPartitions == rdd3.getNumPartitions) + assert(partitioner5.numPartitions == rdd4.getNumPartitions) + assert(partitioner6.numPartitions == sc.defaultParallelism) + assert(partitioner7.numPartitions == sc.defaultParallelism) + } finally { + sc.conf.remove("spark.default.parallelism") + } } } diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala deleted file mode 100644 index 33270bec6247..000000000000 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io.File - -object SSLSampleConfigs { - val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath - val untrustedKeyStorePath = new File( - this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath - val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath - - val enabledAlgorithms = - // A reasonable set of TLSv1.2 Oracle security provider suites - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + - "TLS_RSA_WITH_AES_256_CBC_SHA256, " + - "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " + - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + - "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " + - // and their equivalent names in the IBM Security provider - "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + - "SSL_RSA_WITH_AES_256_CBC_SHA256, " + - "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " + - "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + - "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256" - - def sparkSSLConfig(): SparkConf = { - val conf = new SparkConf(loadDefaults = false) - conf.set("spark.ssl.enabled", "true") - conf.set("spark.ssl.keyStore", keyStorePath) - conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ssl.keyPassword", "password") - conf.set("spark.ssl.trustStore", trustStorePath) - conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) - conf.set("spark.ssl.protocol", "TLSv1.2") - conf - } - - def sparkSSLConfigUntrusted(): SparkConf = { - val conf = new SparkConf(loadDefaults = false) - conf.set("spark.ssl.enabled", "true") - conf.set("spark.ssl.keyStore", untrustedKeyStorePath) - conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ssl.keyPassword", "password") - conf.set("spark.ssl.trustStore", trustStorePath) - conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) - conf.set("spark.ssl.protocol", "TLSv1.2") - conf - } - -} diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index cf59265dd646..e357299770a2 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -370,51 +370,6 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.checkModifyPermissions("user1") === false) } - test("ssl on setup") { - val conf = SSLSampleConfigs.sparkSSLConfig() - val expectedAlgorithms = Set( - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", - "TLS_RSA_WITH_AES_256_CBC_SHA256", - "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", - "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", - "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384", - "SSL_RSA_WITH_AES_256_CBC_SHA256", - "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256", - "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256", - 
"SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") - - val securityManager = new SecurityManager(conf) - - assert(securityManager.fileServerSSLOptions.enabled === true) - - assert(securityManager.sslSocketFactory.isDefined === true) - assert(securityManager.hostnameVerifier.isDefined === true) - - assert(securityManager.fileServerSSLOptions.trustStore.isDefined === true) - assert(securityManager.fileServerSSLOptions.trustStore.get.getName === "truststore") - assert(securityManager.fileServerSSLOptions.keyStore.isDefined === true) - assert(securityManager.fileServerSSLOptions.keyStore.get.getName === "keystore") - assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) - assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) - assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) - assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) - assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) - } - - test("ssl off setup") { - val file = File.createTempFile("SSLOptionsSuite", "conf", Utils.createTempDir()) - - System.setProperty("spark.ssl.configFile", file.getAbsolutePath) - val conf = new SparkConf() - - val securityManager = new SecurityManager(conf) - - assert(securityManager.fileServerSSLOptions.enabled === false) - assert(securityManager.sslSocketFactory.isDefined === false) - assert(securityManager.hostnameVerifier.isDefined === false) - } - test("missing secret authentication key") { val conf = new SparkConf().set("spark.authenticate", "true") val mgr = new SecurityManager(conf) @@ -440,23 +395,41 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(keyFromEnv === new SecurityManager(conf2).getSecretKey()) } - test("secret key generation in yarn mode") { - val conf = new SparkConf() - .set(NETWORK_AUTH_ENABLED, true) - .set(SparkLauncher.SPARK_MASTER, "yarn") - val mgr = new SecurityManager(conf) - - UserGroupInformation.createUserForTesting("authTest", Array()).doAs( - new PrivilegedExceptionAction[Unit]() { - override def run(): Unit = { - mgr.initializeAuth() - val creds = UserGroupInformation.getCurrentUser().getCredentials() - val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) - assert(secret != null) - assert(new String(secret, UTF_8) === mgr.getSecretKey()) + test("secret key generation") { + Seq( + ("yarn", true), + ("local", true), + ("local[*]", true), + ("local[1, 2]", true), + ("local-cluster[2, 1, 1024]", false), + ("invalid", false) + ).foreach { case (master, shouldGenerateSecret) => + val conf = new SparkConf() + .set(NETWORK_AUTH_ENABLED, true) + .set(SparkLauncher.SPARK_MASTER, master) + val mgr = new SecurityManager(conf) + + UserGroupInformation.createUserForTesting("authTest", Array()).doAs( + new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + if (shouldGenerateSecret) { + mgr.initializeAuth() + val creds = UserGroupInformation.getCurrentUser().getCredentials() + val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) + assert(secret != null) + assert(new String(secret, UTF_8) === mgr.getSecretKey()) + } else { + intercept[IllegalArgumentException] { + mgr.initializeAuth() + } + intercept[IllegalArgumentException] { + mgr.getSecretKey() + } + } + } } - } - ) + ) + } } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 3931d53b4ae0..ced5a06516f7 100644 --- 
a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index b30bd74812b3..ce9f2be1c02d 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.File import java.net.{MalformedURLException, URI} import java.nio.charset.StandardCharsets -import java.util.concurrent.{Semaphore, TimeUnit} +import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit} import scala.concurrent.duration._ @@ -498,45 +498,36 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu test("Cancelling stages/jobs with custom reasons.") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") val REASON = "You shall not pass" - val slices = 10 - val listener = new SparkListener { - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { - if (SparkContextSuite.cancelStage) { - eventually(timeout(10.seconds)) { - assert(SparkContextSuite.isTaskStarted) + for (cancelWhat <- Seq("stage", "job")) { + // This countdown latch used to make sure stage or job canceled in listener + val latch = new CountDownLatch(1) + + val listener = cancelWhat match { + case "stage" => + new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sc.cancelStage(taskStart.stageId, REASON) + latch.countDown() + } } - sc.cancelStage(taskStart.stageId, REASON) - SparkContextSuite.cancelStage = false - SparkContextSuite.semaphore.release(slices) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - if (SparkContextSuite.cancelJob) { - eventually(timeout(10.seconds)) { - assert(SparkContextSuite.isTaskStarted) + case "job" => + new SparkListener { + override def onJobStart(jobStart: 
SparkListenerJobStart): Unit = { + sc.cancelJob(jobStart.jobId, REASON) + latch.countDown() + } } - sc.cancelJob(jobStart.jobId, REASON) - SparkContextSuite.cancelJob = false - SparkContextSuite.semaphore.release(slices) - } } - } - sc.addSparkListener(listener) - - for (cancelWhat <- Seq("stage", "job")) { - SparkContextSuite.semaphore.drainPermits() - SparkContextSuite.isTaskStarted = false - SparkContextSuite.cancelStage = (cancelWhat == "stage") - SparkContextSuite.cancelJob = (cancelWhat == "job") + sc.addSparkListener(listener) val ex = intercept[SparkException] { - sc.range(0, 10000L, numSlices = slices).mapPartitions { x => - SparkContextSuite.isTaskStarted = true - // Block waiting for the listener to cancel the stage or job. - SparkContextSuite.semaphore.acquire() + sc.range(0, 10000L, numSlices = 10).mapPartitions { x => + x.synchronized { + x.wait() + } x }.count() } @@ -550,9 +541,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } + latch.await(20, TimeUnit.SECONDS) eventually(timeout(20.seconds)) { assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } + sc.removeSparkListener(listener) } } @@ -637,8 +630,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } object SparkContextSuite { - @volatile var cancelJob = false - @volatile var cancelStage = false @volatile var isTaskStarted = false @volatile var taskKilled = false @volatile var taskSucceeded = false diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 18077c08c9dc..31289026b002 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -27,19 +27,54 @@ import org.apache.spark.util.AccumulatorContext /** * Base abstract class for all unit tests in Spark for handling common functionality. + * + * Thread audit happens normally here automatically when a new test suite created. + * The only prerequisite for that is that the test class must extend [[SparkFunSuite]]. + * + * It is possible to override the default thread audit behavior by setting enableAutoThreadAudit + * to false and manually calling the audit methods, if desired. 
For example: + * + * class MyTestSuite extends SparkFunSuite { + * + * override val enableAutoThreadAudit = false + * + * protected override def beforeAll(): Unit = { + * doThreadPreAudit() + * super.beforeAll() + * } + * + * protected override def afterAll(): Unit = { + * super.afterAll() + * doThreadPostAudit() + * } + * } */ abstract class SparkFunSuite extends FunSuite with BeforeAndAfterAll + with ThreadAudit with Logging { // scalastyle:on + protected val enableAutoThreadAudit = true + + protected override def beforeAll(): Unit = { + System.setProperty("spark.testing", "true") + if (enableAutoThreadAudit) { + doThreadPreAudit() + } + super.beforeAll() + } + protected override def afterAll(): Unit = { try { // Avoid leaking map entries in tests that use accumulators without SparkContext AccumulatorContext.clear() } finally { super.afterAll() + if (enableAutoThreadAudit) { + doThreadPostAudit() + } } } diff --git a/core/src/test/scala/org/apache/spark/ThreadAudit.scala b/core/src/test/scala/org/apache/spark/ThreadAudit.scala new file mode 100644 index 000000000000..b3cea9de8f30 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ThreadAudit.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging + +/** + * Thread audit for test suites. + */ +trait ThreadAudit extends Logging { + + val threadWhiteList = Set( + /** + * Netty related internal threads. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "netty.*", + + /** + * Netty related internal threads. + * A Single-thread singleton EventExecutor inside netty which creates such threads. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "globalEventExecutor.*", + + /** + * Netty related internal threads. + * Checks if a thread is alive periodically and runs a task when a thread dies. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "threadDeathWatcher.*", + + /** + * During [[SparkContext]] creation [[org.apache.spark.rpc.netty.NettyRpcEnv]] + * creates event loops. One is wrapped inside + * [[org.apache.spark.network.server.TransportServer]] + * the other one is inside [[org.apache.spark.network.client.TransportClient]]. + * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. + * Manually checked and all of them stopped properly. + */ + "rpc-client.*", + "rpc-server.*", + + /** + * During [[SparkContext]] creation BlockManager creates event loops. 
One is wrapped inside + * [[org.apache.spark.network.server.TransportServer]] + * the other one is inside [[org.apache.spark.network.client.TransportClient]]. + * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. + * Manually checked and all of them stopped properly. + */ + "shuffle-client.*", + "shuffle-server.*" + ) + private var threadNamesSnapshot: Set[String] = Set.empty + + protected def doThreadPreAudit(): Unit = { + threadNamesSnapshot = runningThreadNames() + } + + protected def doThreadPostAudit(): Unit = { + val shortSuiteName = this.getClass.getName.replaceAll("org.apache.spark", "o.a.s") + + if (threadNamesSnapshot.nonEmpty) { + val remainingThreadNames = runningThreadNames().diff(threadNamesSnapshot) + .filterNot { s => threadWhiteList.exists(s.matches(_)) } + if (remainingThreadNames.nonEmpty) { + logWarning(s"\n\n===== POSSIBLE THREAD LEAK IN SUITE $shortSuiteName, " + + s"thread names: ${remainingThreadNames.mkString(", ")} =====\n") + } + } else { + logWarning("\n\n===== THREAD AUDIT POST ACTION CALLED " + + s"WITHOUT PRE ACTION IN SUITE $shortSuiteName =====\n") + } + } + + private def runningThreadNames(): Set[String] = { + Thread.getAllStackTraces.keySet().asScala.map(_.getName).toSet + } +} diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 159629825c67..9ad2e9a5e74a 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -153,6 +153,40 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio assert(broadcast.value.sum === 10) } + test("One broadcast value instance per executor") { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val instances = sc.parallelize(1 to 10) + .map(x => System.identityHashCode(broadcast.value)) + .collect() + .toSet + + assert(instances.size === 1) + } + + test("One broadcast value instance per executor when memory is constrained") { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + .set("spark.memory.useLegacyMode", "true") + .set("spark.storage.memoryFraction", "0.0") + + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val instances = sc.parallelize(1 to 10) + .map(x => System.identityHashCode(broadcast.value)) + .collect() + .toSet + + assert(instances.size === 1) + } + /** * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. 
* diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 32dd3ecc2f02..ef947eb07464 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -66,7 +66,6 @@ class RPackageUtilsSuite override def beforeEach(): Unit = { super.beforeEach() - System.setProperty("spark.testing", "true") lineBuffer.clear() } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 27dd43533234..7451e07b25a1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy import java.io._ import java.net.URI import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Paths} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -35,12 +35,14 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.TestUtils import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} @@ -105,9 +107,13 @@ class SparkSubmitSuite // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x implicit val defaultSignaler: Signaler = ThreadSignaler + private val emptyIvySettings = File.createTempFile("ivy", ".xml") + FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8) + + private val submit = new SparkSubmit() + override def beforeEach() { super.beforeEach() - System.setProperty("spark.testing", "true") } // scalastyle:off println @@ -125,13 +131,16 @@ class SparkSubmitSuite } test("handle binary specified but not class") { - testPrematureExit(Array("foo.jar"), "No main class") + val jar = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + testPrematureExit(Array(jar.toString()), "No main class") } test("handles arguments with --key=val") { val clArgs = Seq( "--jars=one.jar,two.jar,three.jar", - "--name=myApp") + "--name=myApp", + "--class=org.FooBar", + SparkLauncher.NO_RESOURCE) val appArgs = new SparkSubmitArguments(clArgs) appArgs.jars should include regex (".*one.jar,.*two.jar,.*three.jar") appArgs.name should be ("myApp") @@ -179,7 +188,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) appArgs.deployMode should be ("client") conf.get("spark.submit.deployMode") should be ("client") @@ -189,11 +198,11 @@ class SparkSubmitSuite "--master", "yarn", "--deploy-mode", "cluster", "--conf", "spark.submit.deployMode=client", - "-class", "org.SomeClass", + "--class", "org.SomeClass", "thejar.jar" ) val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = 
submit.prepareSubmitEnvironment(appArgs1) appArgs1.deployMode should be ("cluster") conf1.get("spark.submit.deployMode") should be ("cluster") @@ -207,7 +216,7 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) appArgs2.deployMode should be (null) - val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) appArgs2.deployMode should be ("client") conf2.get("spark.submit.deployMode") should be ("client") } @@ -230,7 +239,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--arg arg1 --arg arg2") @@ -273,7 +282,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -319,7 +328,7 @@ class SparkSubmitSuite "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) appArgs.useRest = useRest - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") @@ -356,7 +365,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -378,7 +387,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -400,7 +409,7 @@ class SparkSubmitSuite "/home/thejar.jar", "arg1") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsMap = childArgs.grouped(2).map(a => a(0) -> a(1)).toMap childArgsMap.get("--primary-java-resource") should be (Some("file:/home/thejar.jar")) @@ -425,7 +434,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.executor.memory") should be ("5g") conf.get("spark.master") should be ("yarn") conf.get("spark.submit.deployMode") should be ("cluster") @@ -438,12 +447,12 @@ class SparkSubmitSuite val clArgs1 = Seq("--class", 
"org.apache.spark.repl.Main", "spark-shell") val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1) conf1.get(UI_SHOW_CONSOLE_PROGRESS) should be (true) val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) assert(!conf2.contains(UI_SHOW_CONSOLE_PROGRESS)) } @@ -520,6 +529,7 @@ class SparkSubmitSuite "--repositories", repo, "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -530,7 +540,6 @@ class SparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") - // Test using "spark.jars.packages" and "spark.jars.repositories" configurations. IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), @@ -540,6 +549,7 @@ class SparkSubmitSuite "--conf", s"spark.jars.repositories=$repo", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -550,7 +560,6 @@ class SparkSubmitSuite // See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log ignore("correctly builds R packages included in a jar with --packages") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -563,6 +572,7 @@ class SparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--packages", main.toString, "--repositories", repo, + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", "--verbose", "--conf", "spark.ui.enabled=false", rScriptDir) @@ -573,7 +583,6 @@ class SparkSubmitSuite test("include an external JAR in SparkR") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val rScriptDir = Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator) @@ -606,10 +615,13 @@ class SparkSubmitSuite } test("resolves command line argument paths correctly") { - val jars = "/jar1,/jar2" // --jars - val files = "local:/file1,file2" // --files - val archives = "file:/archive1,archive2" // --archives - val pyFiles = "py-file1,py-file2" // --py-files + val dir = Utils.createTempDir() + val archive = Paths.get(dir.toPath.toString, "single.zip") + Files.createFile(archive) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" // Test 
jars and files val clArgs = Seq( @@ -619,7 +631,7 @@ class SparkSubmitSuite "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) conf.get("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) @@ -634,11 +646,12 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) - appArgs2.archives should be (Utils.resolveURIs(archives)) + appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3") conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) - conf2.get("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) + conf2.get("spark.yarn.dist.archives") should fullyMatch regex + ("file:/archive1,file:.*#archive3") // Test python files val clArgs3 = Seq( @@ -649,7 +662,7 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) conf3.get("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -657,6 +670,29 @@ class SparkSubmitSuite conf3.get(PYSPARK_PYTHON.key) should be ("python3.5") } + test("ambiguous archive mapping results in error message") { + val dir = Utils.createTempDir() + val archive1 = Paths.get(dir.toPath.toString, "first.zip") + val archive2 = Paths.get(dir.toPath.toString, "second.zip") + Files.createFile(archive1) + Files.createFile(archive2) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" + + // Test files and archives (Yarn) + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "--files", files, + "--archives", archives, + "thejar.jar" + ) + + testPrematureExit(clArgs2.toArray, "resolves ambiguously to multiple files") + } + test("resolves config paths correctly") { val jars = "/jar1,/jar2" // spark.jars val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files @@ -678,7 +714,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) conf.get("spark.files") should be(Utils.resolveURIs(files)) @@ -695,7 +731,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) @@ -710,7 +746,7 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + val (_, _, 
conf3, _) = submit.prepareSubmitEnvironment(appArgs3) conf3.get("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -727,7 +763,7 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val (_, _, conf4, _) = SparkSubmit.prepareSubmitEnvironment(appArgs4) + val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4) // Should not format python path for yarn cluster mode conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } @@ -748,32 +784,20 @@ class SparkSubmitSuite } test("SPARK_CONF_DIR overrides spark-defaults.conf") { - forConfDir(Map("spark.executor.memory" -> "2.3g")) { path => + forConfDir(Map("spark.executor.memory" -> "3g")) { path => val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", unusedJar.toString) - val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path)) + val appArgs = new SparkSubmitArguments(args, env = Map("SPARK_CONF_DIR" -> path)) assert(appArgs.propertiesFile != null) assert(appArgs.propertiesFile.startsWith(path)) - appArgs.executorMemory should be ("2.3g") + appArgs.executorMemory should be ("3g") } } - test("comma separated list of files are unioned correctly") { - val left = Option("/tmp/a.jar,/tmp/b.jar") - val right = Option("/tmp/c.jar,/tmp/a.jar") - val emptyString = Option("") - Utils.unionFileLists(left, right) should be (Set("/tmp/a.jar", "/tmp/b.jar", "/tmp/c.jar")) - Utils.unionFileLists(emptyString, emptyString) should be (Set.empty) - Utils.unionFileLists(Option("/tmp/a.jar"), emptyString) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(emptyString, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar")) - } - test("support glob path") { val tmpJarDir = Utils.createTempDir() val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir) @@ -791,6 +815,9 @@ class SparkSubmitSuite val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir) val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir) + val tempPyFile = File.createTempFile("tmpApp", ".py") + tempPyFile.deleteOnExit() + val args = Seq( "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), "--name", "testApp", @@ -800,10 +827,10 @@ class SparkSubmitSuite "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*", "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*", "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip", - jar2.toString) + tempPyFile.toURI().toString()) val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.yarn.dist.jars").split(",").toSet should be (Set(jar1.toURI.toString, jar2.toURI.toString)) conf.get("spark.yarn.dist.files").split(",").toSet should be @@ -929,7 +956,7 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) // All the resources should still be remote paths, so that YARN client will not upload again. 
conf.get("spark.yarn.dist.jars") should be (tmpJarPath) @@ -944,25 +971,28 @@ class SparkSubmitSuite } test("download remote resource if it is not supported by yarn service") { - testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = false) + testRemoteResources(enableHttpFs = false, blacklistHttpFs = false) } test("avoid downloading remote resource if it is supported by yarn service") { - testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistHttpFs = false) } test("force download from blacklisted schemes") { - testRemoteResources(isHttpSchemeBlacklisted = true, supportMockHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistHttpFs = true) } - private def testRemoteResources(isHttpSchemeBlacklisted: Boolean, - supportMockHttpFs: Boolean): Unit = { + private def testRemoteResources( + enableHttpFs: Boolean, + blacklistHttpFs: Boolean): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) - if (supportMockHttpFs) { + if (enableHttpFs) { hadoopConf.set("fs.http.impl", classOf[TestFileSystem].getCanonicalName) - hadoopConf.set("fs.http.impl.disable.cache", "true") + } else { + hadoopConf.set("fs.http.impl", getClass().getName() + ".DoesNotExist") } + hadoopConf.set("fs.http.impl.disable.cache", "true") val tmpDir = Utils.createTempDir() val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) @@ -971,30 +1001,29 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" + val forceDownloadArgs = if (blacklistHttpFs) { + Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http") + } else { + Nil + } + val args = Seq( "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), "--name", "testApp", "--master", "yarn", "--deploy-mode", "client", - "--jars", s"$tmpS3JarPath,$tmpHttpJarPath", - s"s3a://$mainResource" - ) ++ ( - if (isHttpSchemeBlacklisted) { - Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http,https") - } else { - Nil - } - ) + "--jars", s"$tmpS3JarPath,$tmpHttpJarPath" + ) ++ forceDownloadArgs ++ Seq(s"s3a://$mainResource") val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) val jars = conf.get("spark.yarn.dist.jars").split(",").toSet // The URI of remote S3 resource should still be remote. assert(jars.contains(tmpS3JarPath)) - if (supportMockHttpFs) { + if (enableHttpFs && !blacklistHttpFs) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. 
assert(jars.contains(tmpHttpJarPath)) @@ -1038,11 +1067,12 @@ class SparkSubmitSuite "hello") val exception = intercept[SparkException] { - SparkSubmit.main(args) + submit.doSubmit(args) } assert(exception.getMessage() === "hello") } + } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { @@ -1077,7 +1107,7 @@ object SparkSubmitSuite extends SparkFunSuite with TimeLimits { object JarCreationTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => @@ -1101,7 +1131,7 @@ object JarCreationTest extends Logging { object SimpleApplicationTest { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val configs = Seq("spark.master", "spark.app.name") diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index bf7480d79f8a..27cc47496c80 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -573,7 +573,8 @@ class StandaloneDynamicAllocationSuite syncExecutors(sc) sc.schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.killExecutors(Seq(executorId), replace = false, force) + b.killExecutors(Seq(executorId), adjustTargetNumExecutors = true, countFailures = false, + force) case _ => fail("expected coarse grained scheduler") } } @@ -610,7 +611,7 @@ class StandaloneDynamicAllocationSuite * we submit a request to kill them. This must be called before each kill request. 
*/ private def syncExecutors(sc: SparkContext): Unit = { - val driverExecutors = sc.getExecutorStorageStatus + val driverExecutors = sc.env.blockManager.master.getStorageStatus .map(_.blockManagerId.executorId) .filter { _ != SparkContext.DRIVER_IDENTIFIER} val masterExecutors = getExecutorIds(sc) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 84ee01c7f5aa..77b239489d48 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -149,8 +149,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) { var mergeApplicationListingCall = 0 - override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { - super.mergeApplicationListing(fileStatus) + override protected def mergeApplicationListing( + fileStatus: FileStatus, + lastSeen: Long, + enableSkipToEnd: Boolean): Unit = { + super.mergeApplicationListing(fileStatus, lastSeen, enableSkipToEnd) mergeApplicationListingCall += 1 } } @@ -254,14 +257,13 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) updateAndCheck(provider) { list => - list should not be (null) list.size should be (1) list.head.attempts.size should be (3) list.head.attempts.head.attemptId should be (Some("attempt3")) } val app2Attempt1 = newLogFile("app2", Some("attempt1"), inProgress = false) - writeFile(attempt1, true, None, + writeFile(app2Attempt1, true, None, SparkListenerApplicationStart("app2", Some("app2"), 5L, "test", Some("attempt1")), SparkListenerApplicationEnd(6L) ) @@ -380,8 +382,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val log = newLogFile("downloadApp1", Some(s"attempt$i"), inProgress = false) writeFile(log, true, None, SparkListenerApplicationStart( - "downloadApp1", Some("downloadApp1"), 5000 * i, "test", Some(s"attempt$i")), - SparkListenerApplicationEnd(5001 * i) + "downloadApp1", Some("downloadApp1"), 5000L * i, "test", Some(s"attempt$i")), + SparkListenerApplicationEnd(5001L * i) ) log } @@ -647,8 +649,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Add more info to the app log, and trigger the provider to update things. writeFile(appLog, true, None, SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None), - SparkListenerJobStart(0, 1L, Nil, null), - SparkListenerApplicationEnd(5L) + SparkListenerJobStart(0, 1L, Nil, null) ) provider.checkForLogs() @@ -663,6 +664,160 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc freshUI.get.ui.store.job(0) } + test("clean up stale app information") { + val storeDir = Utils.createTempDir() + val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) + val clock = new ManualClock() + val provider = spy(new FsHistoryProvider(conf, clock)) + val appId = "new1" + + // Write logs for two app attempts. 
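// The clock below is a ManualClock, so the test rather than the wall clock decides how much
// time elapses between scans. A minimal sketch of the idea, assuming only the two methods this
// suite relies on (the real class is org.apache.spark.util.ManualClock):
class ManualClockSketch(private var now: Long = 0L) {
  def advance(ms: Long): Unit = { now += ms }   // move "now" forward deterministically
  def getTimeMillis(): Long = now               // what the provider observes as current time
}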
+ clock.advance(1) + val attempt1 = newLogFile(appId, Some("1"), inProgress = false) + writeFile(attempt1, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + val attempt2 = newLogFile(appId, Some("2"), inProgress = false) + writeFile(attempt2, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("2")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 2) + } + + // Load the app's UI. + val ui = provider.getAppUI(appId, Some("1")) + assert(ui.isDefined) + + // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since + // attempt 2 still exists, listing data should be there. + clock.advance(1) + attempt1.delete() + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 1) + } + assert(!ui.get.valid) + assert(provider.getAppUI(appId, None) === None) + + // Delete the second attempt's log file. Now everything should go away. + clock.advance(1) + attempt2.delete() + updateAndCheck(provider) { list => + assert(list.isEmpty) + } + } + + test("SPARK-21571: clean up removes invalid history files") { + val clock = new ManualClock() + val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") + val provider = new FsHistoryProvider(conf, clock) + + // Create 0-byte size inprogress and complete files + var logCount = 0 + var validLogCount = 0 + + val emptyInProgress = newLogFile("emptyInprogressLogFile", None, inProgress = true) + emptyInProgress.createNewFile() + emptyInProgress.setLastModified(clock.getTimeMillis()) + logCount += 1 + + val slowApp = newLogFile("slowApp", None, inProgress = true) + slowApp.createNewFile() + slowApp.setLastModified(clock.getTimeMillis()) + logCount += 1 + + val emptyFinished = newLogFile("emptyFinishedLogFile", None, inProgress = false) + emptyFinished.createNewFile() + emptyFinished.setLastModified(clock.getTimeMillis()) + logCount += 1 + + // Create an incomplete log file, has an end record but no start record. + val corrupt = newLogFile("nonEmptyCorruptLogFile", None, inProgress = false) + writeFile(corrupt, true, None, SparkListenerApplicationEnd(0)) + corrupt.setLastModified(clock.getTimeMillis()) + logCount += 1 + + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === logCount) + + // Move the clock forward 1 day and scan the files again. They should still be there. + clock.advance(TimeUnit.DAYS.toMillis(1)) + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === logCount) + + // Update the slow app to contain valid info. Code should detect the change and not clean + // it up. + writeFile(slowApp, true, None, + SparkListenerApplicationStart(slowApp.getName(), Some(slowApp.getName()), 1L, "test", None)) + slowApp.setLastModified(clock.getTimeMillis()) + validLogCount += 1 + + // Move the clock forward another 2 days and scan the files again. This time the cleaner should + // pick up the invalid files and get rid of them. 
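// The retention rule exercised by this scan, as a sketch (the real cleaner additionally keeps
// files that have since produced a valid application listing, which is why slowApp survives):
def isExpired(lastModifiedMs: Long, nowMs: Long, maxAgeMs: Long): Boolean = {
  nowMs - lastModifiedMs > maxAgeMs
}
// With MAX_LOG_AGE_S set to 2d, the empty files written at t0 outlive the scan at t0 + 1 day
// but not this scan at t0 + 3 days.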
+ clock.advance(TimeUnit.DAYS.toMillis(2)) + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === validLogCount) + } + + test("always find end event for finished apps") { + // Create a log file where the end event is before the configure chunk to be reparsed at + // the end of the file. The correct listing should still be generated. + val log = newLogFile("end-event-test", None, inProgress = false) + writeFile(log, true, None, + Seq( + SparkListenerApplicationStart("end-event-test", Some("end-event-test"), 1L, "test", None), + SparkListenerEnvironmentUpdate(Map( + "Spark Properties" -> Seq.empty, + "JVM Information" -> Seq.empty, + "System Properties" -> Seq.empty, + "Classpath Entries" -> Seq.empty + )), + SparkListenerApplicationEnd(5L) + ) ++ (1 to 1000).map { i => SparkListenerJobStart(i, i, Nil) }: _*) + + val conf = createTestConf().set(END_EVENT_REPARSE_CHUNK_SIZE.key, s"1k") + val provider = new FsHistoryProvider(conf) + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).attempts.size === 1) + assert(list(0).attempts(0).completed) + } + } + + test("parse event logs with optimizations off") { + val conf = createTestConf() + .set(END_EVENT_REPARSE_CHUNK_SIZE, 0L) + .set(FAST_IN_PROGRESS_PARSING, false) + val provider = new FsHistoryProvider(conf) + + val complete = newLogFile("complete", None, inProgress = false) + writeFile(complete, true, None, + SparkListenerApplicationStart("complete", Some("complete"), 1L, "test", None), + SparkListenerApplicationEnd(5L) + ) + + val incomplete = newLogFile("incomplete", None, inProgress = true) + writeFile(incomplete, true, None, + SparkListenerApplicationStart("incomplete", Some("incomplete"), 1L, "test", None) + ) + + updateAndCheck(provider) { list => + list.size should be (2) + list.count(_.attempts.head.completed) should be (1) + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. 
Example: @@ -706,7 +861,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc private def createTestConf(inMemory: Boolean = false): SparkConf = { val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + .set(EVENT_LOG_DIR, testDir.getAbsolutePath()) + .set(FAST_IN_PROGRESS_PARSING, true) if (!inMemory) { conf.set(LOCAL_STORE_DIR, Utils.createTempDir().getAbsolutePath()) @@ -739,4 +895,3 @@ class TestGroupsMappingProvider extends GroupMappingServiceProvider { mappings.get(username).map(Set(_)).getOrElse(Set.empty) } } - diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 3738f85da583..a871b1c71783 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -48,7 +48,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.status.api.v1.ApplicationInfo import org.apache.spark.status.api.v1.JobData import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.apache.spark.util.{ResetSystemProperties, ShutdownHookManager, Utils} /** * A collection of tests against the historyserver, including comparing responses from the json @@ -156,6 +156,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "applications/local-1426533911241/1/stages/0/0/taskList", "stage task list from multi-attempt app json(2)" -> "applications/local-1426533911241/2/stages/0/0/taskList", + "blacklisting for stage" -> "applications/app-20180109111548-0000/stages/0/0", + "blacklisting node for stage" -> "applications/application_1516285256255_0012/stages/0/0", "rdd list storage json" -> "applications/local-1422981780767/storage/rdd", "executor node blacklisting" -> "applications/app-20161116163331-0000/executors", @@ -294,6 +296,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("/version api endpoint") { + val response = getUrl("version") + assert(response.contains(SPARK_VERSION)) + } + test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = "/testwebproxybase" System.setProperty("spark.ui.proxyBase", uiRoot) @@ -486,7 +493,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers json match { case JNothing => Seq() case apps: JArray => - apps.filter(app => { + apps.children.filter(app => { (app \ "attempts") match { case attempts: JArray => val state = (attempts.children.head \ "completed").asInstanceOf[JBool] @@ -564,7 +571,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit() + ShutdownHookManager.registerShutdownDeleteDir(logDir) } test("ui and api authorization checks") { diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index e505bc018857..54c168a8218f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -445,7 +445,7 @@ class 
StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { "--class", mainClass, mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) - val (_, _, sparkConf, _) = SparkSubmit.prepareSubmitEnvironment(args) + val (_, _, sparkConf, _) = new SparkSubmit().prepareSubmitEnvironment(args) new RestSubmissionClient("spark://host:port").constructSubmitRequest( mainJar, mainClass, appArgs, sparkConf.getAll.toMap, Map.empty) } diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala index eeffc36070b4..2849a10a2c81 100644 --- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.security +import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials @@ -110,7 +111,64 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { creds.getAllTokens.size should be (0) } + test("SPARK-23209: obtain tokens when Hive classes are not available") { + // This test needs a custom class loader to hide Hive classes which are in the classpath. + // Because the manager code loads the Hive provider directly instead of using reflection, we + // need to drive the test through the custom class loader so a new copy that cannot find + // Hive classes is loaded. + val currentLoader = Thread.currentThread().getContextClassLoader() + val noHive = new ClassLoader() { + override def loadClass(name: String, resolve: Boolean): Class[_] = { + if (name.startsWith("org.apache.hive") || name.startsWith("org.apache.hadoop.hive")) { + throw new ClassNotFoundException(name) + } + + if (name.startsWith("java") || name.startsWith("scala")) { + currentLoader.loadClass(name) + } else { + val classFileName = name.replaceAll("\\.", "/") + ".class" + val in = currentLoader.getResourceAsStream(classFileName) + if (in != null) { + val bytes = IOUtils.toByteArray(in) + defineClass(name, bytes, 0, bytes.length) + } else { + throw new ClassNotFoundException(name) + } + } + } + } + + try { + Thread.currentThread().setContextClassLoader(noHive) + val test = noHive.loadClass(NoHiveTest.getClass.getName().stripSuffix("$")) + test.getMethod("runTest").invoke(null) + } finally { + Thread.currentThread().setContextClassLoader(currentLoader) + } + } + private[spark] def hadoopFSsToAccess(hadoopConf: Configuration): Set[FileSystem] = { Set(FileSystem.get(hadoopConf)) } } + +/** Test code for SPARK-23209 to avoid using too much reflection above. */ +private object NoHiveTest extends Matchers { + + def runTest(): Unit = { + try { + val manager = new HadoopDelegationTokenManager(new SparkConf(), new Configuration(), + _ => Set()) + manager.getServiceDelegationTokenProvider("hive") should be (None) + } catch { + case e: Throwable => + // Throw a better exception in case the test fails, since there may be a lot of nesting. 
+ var cause = e + while (cause.getCause() != null) { + cause = cause.getCause() + } + throw cause + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 105a178f2d94..1a7bebe2c53c 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler import java.nio.ByteBuffer import java.util.Properties import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.Map import scala.concurrent.duration._ @@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug // the fetch failure. The executor should still tell the driver that the task failed due to a // fetch failure, not a generic exception from user code. val inputRDD = new FetchFailureThrowingRDD(sc) - val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() val task = new ResultTask( @@ -173,17 +174,48 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } test("SPARK-19276: OOMs correctly handled with a FetchFailure") { + val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true) + assert(failReason.isInstanceOf[ExceptionFailure]) + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) + verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) + assert(exceptionCaptor.getAllValues.size === 1) + assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError]) + } + + test("SPARK-23816: interrupts are not masked by a FetchFailure") { + // If killing the task causes a fetch failure, we still treat it as a task that was killed, + // as the fetch failure could easily be caused by interrupting the thread. + val (failReason, _) = testFetchFailureHandling(false) + assert(failReason.isInstanceOf[TaskKilled]) + } + + /** + * Helper for testing some cases where a FetchFailure should *not* get sent back, because its + * superceded by another error, either an OOM or intentionally killing a task. + * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the + * FetchFailure + */ + private def testFetchFailureHandling( + oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = { // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it // may be a false positive. And we should call the uncaught exception handler. + // SPARK-23816 also handle interrupts the same way, as killing an obsolete speculative task + // does not represent a real fetch failure. val conf = new SparkConf().setMaster("local").setAppName("executor suite test") sc = new SparkContext(conf) val serializer = SparkEnv.get.closureSerializer.newInstance() val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size - // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat - // the fetch failure as a false positive, and just do normal OOM handling. 
+ // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We + // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling. val inputRDD = new FetchFailureThrowingRDD(sc) - val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true) + if (!oom) { + // we are trying to setup a case where a task is killed after a fetch failure -- this + // is just a helper to coordinate between the task thread and this thread that will + // kill the task + ExecutorSuiteHelper.latches = new ExecutorSuiteHelper() + } + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() val task = new ResultTask( @@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val serTask = serializer.serialize(task) val taskDescription = createFakeTaskDescription(serTask) - val (failReason, uncaughtExceptionHandler) = - runTaskGetFailReasonAndExceptionHandler(taskDescription) - // make sure the task failure just looks like a OOM, not a fetch failure - assert(failReason.isInstanceOf[ExceptionFailure]) - val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) - verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) - assert(exceptionCaptor.getAllValues.size === 1) - assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError]) - } + runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom) + } test("Gracefully handle error in task deserialization") { val conf = new SparkConf @@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { - runTaskGetFailReasonAndExceptionHandler(taskDescription)._1 + runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1 } private def runTaskGetFailReasonAndExceptionHandler( - taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = { + taskDescription: TaskDescription, + killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = { val mockBackend = mock[ExecutorBackend] val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] var executor: Executor = null + val timedOut = new AtomicBoolean(false) try { executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, uncaughtExceptionHandler = mockUncaughtExceptionHandler) // the task will be launched in a dedicated worker thread executor.launchTask(mockBackend, taskDescription) + if (killTask) { + val killingThread = new Thread("kill-task") { + override def run(): Unit = { + // wait to kill the task until it has thrown a fetch failure + if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) { + // now we can kill the task + executor.killAllTasks(true, "Killed task, eg. 
because of speculative execution") + } else { + timedOut.set(true) + } + } + } + killingThread.start() + } eventually(timeout(5.seconds), interval(10.milliseconds)) { assert(executor.numRunningTasks === 0) } + assert(!timedOut.get(), "timed out waiting to be ready to kill tasks") } finally { if (executor != null) { executor.stop() @@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) orderedMock.verify(mockBackend) .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) + val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED orderedMock.verify(mockBackend) - .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) + .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture()) // first statusUpdate for RUNNING has empty data assert(statusCaptor.getAllValues().get(0).remaining() === 0) // second update is more interesting @@ -321,7 +364,8 @@ class SimplePartition extends Partition { class FetchFailureHidingRDD( sc: SparkContext, val input: FetchFailureThrowingRDD, - throwOOM: Boolean) extends RDD[Int](input) { + throwOOM: Boolean, + interrupt: Boolean) extends RDD[Int](input) { override def compute(split: Partition, context: TaskContext): Iterator[Int] = { val inItr = input.compute(split, context) try { @@ -330,6 +374,15 @@ class FetchFailureHidingRDD( case t: Throwable => if (throwOOM) { throw new OutOfMemoryError("OOM while handling another exception") + } else if (interrupt) { + // make sure our test is setup correctly + assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined) + // signal our test is ready for the task to get killed + ExecutorSuiteHelper.latches.latch1.countDown() + // then wait for another thread in the test to kill the task -- this latch + // is never actually decremented, we just wait to get killed. + ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS) + throw new IllegalStateException("timed out waiting to be interrupted") } else { throw new RuntimeException("User Exception that hides the original exception", t) } @@ -352,6 +405,11 @@ private class ExecutorSuiteHelper { @volatile var testFailedReason: TaskFailedReason = _ } +// helper for coordinating killing tasks +private object ExecutorSuiteHelper { + var latches: ExecutorSuiteHelper = null +} + private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable { def writeExternal(out: ObjectOutput): Unit = {} def readExternal(in: ObjectInput): Unit = { diff --git a/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala new file mode 100644 index 000000000000..2bd32fc927e2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.spark.SparkFunSuite + +/** + * Unit tests for instantiation of FileCommitProtocol implementations. + */ +class FileCommitProtocolInstantiationSuite extends SparkFunSuite { + + test("Dynamic partitions require appropriate constructor") { + + // you cannot instantiate a two-arg client with dynamic partitions + // enabled. + val ex = intercept[IllegalArgumentException] { + instantiateClassic(true) + } + // check the contents of the message and rethrow if unexpected. + // this preserves the stack trace of the unexpected + // exception. + if (!ex.toString.contains("Dynamic Partition Overwrite")) { + fail(s"Wrong text in caught exception $ex", ex) + } + } + + test("Standard partitions work with classic constructor") { + instantiateClassic(false) + } + + test("Three arg constructors have priority") { + assert(3 == instantiateNew(false).argCount, + "Wrong constructor argument count") + } + + test("Three arg constructors have priority when dynamic") { + assert(3 == instantiateNew(true).argCount, + "Wrong constructor argument count") + } + + test("The protocol must be of the correct class") { + intercept[ClassCastException] { + FileCommitProtocol.instantiate( + classOf[Other].getCanonicalName, + "job", + "path", + false) + } + } + + test("If there is no matching constructor, class hierarchy is irrelevant") { + intercept[NoSuchMethodException] { + FileCommitProtocol.instantiate( + classOf[NoMatchingArgs].getCanonicalName, + "job", + "path", + false) + } + } + + /** + * Create a classic two-arg protocol instance. + * @param dynamic dyanmic partitioning mode + * @return the instance + */ + private def instantiateClassic(dynamic: Boolean): ClassicConstructorCommitProtocol = { + FileCommitProtocol.instantiate( + classOf[ClassicConstructorCommitProtocol].getCanonicalName, + "job", + "path", + dynamic).asInstanceOf[ClassicConstructorCommitProtocol] + } + + /** + * Create a three-arg protocol instance. + * @param dynamic dyanmic partitioning mode + * @return the instance + */ + private def instantiateNew( + dynamic: Boolean): FullConstructorCommitProtocol = { + FileCommitProtocol.instantiate( + classOf[FullConstructorCommitProtocol].getCanonicalName, + "job", + "path", + dynamic).asInstanceOf[FullConstructorCommitProtocol] + } + +} + +/** + * This protocol implementation does not have the new three-arg + * constructor. + */ +private class ClassicConstructorCommitProtocol(arg1: String, arg2: String) + extends HadoopMapReduceCommitProtocol(arg1, arg2) { +} + +/** + * This protocol implementation does have the new three-arg constructor + * alongside the original, and a 4 arg one for completeness. + * The final value of the real constructor is the number of arguments + * used in the 2- and 3- constructor, for test assertions. 
+ */ +private class FullConstructorCommitProtocol( + arg1: String, + arg2: String, + b: Boolean, + val argCount: Int) + extends HadoopMapReduceCommitProtocol(arg1, arg2, b) { + + def this(arg1: String, arg2: String) = { + this(arg1, arg2, false, 2) + } + + def this(arg1: String, arg2: String, b: Boolean) = { + this(arg1, arg2, false, 3) + } +} + +/** + * This has the 2-arity constructor, but isn't the right class. + */ +private class Other(arg1: String, arg2: String) { + +} + +/** + * This has no matching arguments as well as being the wrong class. + */ +private class NoMatchingArgs() { + +} + diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 3b798e36b049..2107559572d7 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -21,11 +21,12 @@ import java.nio.ByteBuffer import com.google.common.io.ByteStreams -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.util.io.ChunkedByteBuffer -class ChunkedByteBufferSuite extends SparkFunSuite { +class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { test("no chunks") { val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer]) @@ -56,6 +57,18 @@ class ChunkedByteBufferSuite extends SparkFunSuite { assert(chunkedByteBuffer.getChunks().head.position() === 0) } + test("SPARK-24107: writeFully() write buffer which is larger than bufferWriteChunkSize") { + try { + sc.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 32L * 1024L * 1024L) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(40 * 1024 * 1024))) + val byteArrayWritableChannel = new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt) + chunkedByteBuffer.writeFully(byteArrayWritableChannel) + assert(byteArrayWritableChannel.length() === chunkedByteBuffer.size) + } finally { + sc.conf.remove(config.BUFFER_WRITE_CHUNK_SIZE) + } + } + test("toArray()") { val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte)) diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 362cd861cc24..dcf89e4f75ac 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -29,6 +29,7 @@ object MemoryTestingUtils { val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 0, attemptNumber = 0, diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index f7bc3725d727..78423ee68a0e 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -80,6 +80,7 @@ class NettyBlockTransferServiceSuite private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = { actualPort should be >= expectedPort // avoid testing equality in case of simultaneous tests + // if 
`spark.testing` is true, // the default value for `spark.port.maxRetries` is 100 under test actualPort should be <= (expectedPort + 100) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index a39e0469272f..47af5c3320dd 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -322,8 +322,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } // See SPARK-22465 - test("cogroup between multiple RDD" + - " with number of partitions similar in order of magnitude") { + test("cogroup between multiple RDD with number of partitions similar in order of magnitude") { val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) val rdd2 = sc .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) @@ -332,6 +331,48 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(joined.getNumPartitions == rdd2.getNumPartitions) } + test("cogroup between multiple RDD when defaultParallelism is set without proper partitioner") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)), 10) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == sc.defaultParallelism) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + + test("cogroup between multiple RDD when defaultParallelism is set with proper partitioner") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd2.getNumPartitions) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + + test("cogroup between multiple RDD when defaultParallelism is set; with huge number of " + + "partitions in upstream RDDs") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 1000) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd2.getNumPartitions) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + test("rightOuterJoin") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index e994d724c462..191c61250ce2 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -1129,6 +1129,35 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { }.collect() } + test("SPARK-23496: order of input partitions can result in severe skew in coalesce") { + val numInputPartitions = 100 + val numCoalescedPartitions = 50 + val locations = Array("locA", "locB") + + val inputRDD = sc.makeRDD(Range(0, numInputPartitions).toArray[Int], numInputPartitions) + 
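// The wrapper RDD built below (LocationPrefRDD, defined at the bottom of this file) pins the
// first numCoalescedPartitions input partitions to locA and the rest to locB. Depending on input
// order, the coalescer could previously lump almost all coalesced partitions onto a single
// location; the assertions at the end check that both locations receive a reasonably even share
// (more than 40% each), and, as the comment there notes, the coalescer's fixed seed keeps the
// check deterministic.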
assert(inputRDD.getNumPartitions == numInputPartitions) + + val locationPrefRDD = new LocationPrefRDD(inputRDD, { (p: Partition) => + if (p.index < numCoalescedPartitions) { + Seq(locations(0)) + } else { + Seq(locations(1)) + } + }) + val coalescedRDD = new CoalescedRDD(locationPrefRDD, numCoalescedPartitions) + + val numPartsPerLocation = coalescedRDD + .getPartitions + .map(coalescedRDD.getPreferredLocations(_).head) + .groupBy(identity) + .mapValues(_.size) + + // Make sure the coalesced partitions are distributed fairly evenly between the two locations. + // This should not become flaky since the DefaultPartitionsCoalescer uses a fixed seed. + assert(numPartsPerLocation(locations(0)) > 0.4 * numCoalescedPartitions) + assert(numPartsPerLocation(locations(1)) > 0.4 * numCoalescedPartitions) + } + // NOTE // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests // running after them and if they access sc those tests will fail as sc is already closed, because @@ -1210,3 +1239,16 @@ class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Seria groups.toArray } } + +/** Alters the preferred locations of the parent RDD using provided function. */ +class LocationPrefRDD[T: ClassTag]( + @transient var prev: RDD[T], + val locationPicker: Partition => Seq[String]) extends RDD[T](prev) { + override protected def getPartitions: Array[Partition] = prev.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[T] = + null.asInstanceOf[Iterator[T]] + + override def getPreferredLocations(partition: Partition): Seq[String] = + locationPicker(partition) +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index cd1b7a9e5ab1..96c8404327e2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -92,7 +92,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } def createTaskSetBlacklist(stageId: Int = 0): TaskSetBlacklist = { - new TaskSetBlacklist(conf, stageId, clock) + new TaskSetBlacklist(listenerBusMock, conf, stageId, stageAttemptId = 0, clock = clock) } test("executors can be blacklisted with only a few failures per stage") { @@ -479,7 +479,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M test("blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { val allocationClientMock = mock[ExecutorAllocationClient] - when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called")) when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist // is updated before we ask the executor allocation client to kill all the executors @@ -517,7 +517,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) - verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) // Enable auto-kill. 
Blacklist an executor and make sure killExecutors is called. @@ -533,7 +533,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) - verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock).killExecutors(Seq("1"), false, false, true) val taskSetBlacklist3 = createTaskSetBlacklist(stageId = 1) // Fail 4 tasks in one task set on executor 2, so that executor gets blacklisted for the whole @@ -545,13 +545,13 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures) - verify(allocationClientMock).killExecutors(Seq("2"), true, true) + verify(allocationClientMock).killExecutors(Seq("2"), false, false, true) verify(allocationClientMock).killExecutorsOnHost("hostA") } test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { val allocationClientMock = mock[ExecutorAllocationClient] - when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called")) when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist // is updated before we ask the executor allocation client to kill all the executors @@ -571,16 +571,19 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M conf.set(config.BLACKLIST_KILL_ENABLED, false) blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") - verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) + assert(blacklist.nodeToBlacklistedExecs.contains("hostA")) + assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1")) + // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. conf.set(config.BLACKLIST_KILL_ENABLED, true) blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) clock.advance(1000) blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") - verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock).killExecutors(Seq("1"), false, false, true) verify(allocationClientMock, never).killExecutorsOnHost(any()) assert(blacklist.executorIdToBlacklistStatus.contains("1")) @@ -589,6 +592,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty) + assert(blacklist.nodeToBlacklistedExecs.contains("hostA")) + assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1")) // Enable external shuffle service to see if all the executors on this node will be killed. 
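// With the external shuffle service enabled, shuffle blocks are served by the host rather than
// by individual executors, so a fetch failure is attributed to the whole node: blacklisting is
// expected to escalate from the single executor to every executor on hostA.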
conf.set(config.SHUFFLE_SERVICE_ENABLED, true) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d812b5bd92c1..8b6ec37625ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2146,6 +2146,58 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } + test("Trigger mapstage's job listener in submitMissingTasks") { + val rdd1 = new MyRDD(sc, 2, Nil) + val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2)) + val rdd2 = new MyRDD(sc, 2, List(dep1), tracker = mapOutputTracker) + val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2)) + + val listener1 = new SimpleListener + val listener2 = new SimpleListener + + submitMapStage(dep1, listener1) + submitMapStage(dep2, listener2) + + // Complete the stage0. + assert(taskSets(0).stageId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(listener1.results.size === 1) + + // When attempting stage1, trigger a fetch failure. + assert(taskSets(1).stageId === 1) + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)), + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + // Stage1 listener should not have a result yet + assert(listener2.results.size === 0) + + // Speculative task succeeded in stage1. + runEvent(makeCompletionEvent( + taskSets(1).tasks(1), + Success, + makeMapStatus("hostD", rdd2.partitions.length))) + // stage1 listener still should not have a result, though there's no missing partitions + // in it. Because stage1 has been failed and is not inside `runningStages` at this moment. + assert(listener2.results.size === 0) + + // Stage0 should now be running as task set 2; make its task succeed + assert(taskSets(2).stageId === 0) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + + // After stage0 is finished, stage1 will be submitted and found there is no missing + // partitions in it. Then listener got triggered. + assert(listener2.results.size === 1) + assertDataStructuresEmpty() + } + /** * In this test, we run a map stage where one of the executors fails but we still receive a * "zombie" complete message from that executor. 
We want to make sure the stage is not reported diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 73e7b3fe8c1d..e24d550a6266 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -47,7 +47,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } test("Simple replay") { - val logFilePath = Utils.getFilePath(testDir, "events.txt") + val logFilePath = getFilePath(testDir, "events.txt") val fstream = fileSystem.create(logFilePath) val writer = new PrintWriter(fstream) val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, @@ -97,7 +97,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp // scalastyle:on println } - val logFilePath = Utils.getFilePath(testDir, "events.lz4.inprogress") + val logFilePath = getFilePath(testDir, "events.lz4.inprogress") val bytes = buffered.toByteArray Utils.tryWithResource(fileSystem.create(logFilePath)) { fstream => fstream.write(bytes, 0, buffered.size) @@ -129,7 +129,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } test("Replay incompatible event log") { - val logFilePath = Utils.getFilePath(testDir, "incompatible.txt") + val logFilePath = getFilePath(testDir, "incompatible.txt") val fstream = fileSystem.create(logFilePath) val writer = new PrintWriter(fstream) val applicationStart = SparkListenerApplicationStart("Incompatible App", None, @@ -226,6 +226,12 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } } + private def getFilePath(dir: File, fileName: String): Path = { + assert(dir.isDirectory) + val path = new File(dir, fileName).getAbsolutePath + new Path(path) + } + /** * A simple listener that buffers all the events it receives. * diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 1beb36afa95f..da6ecb82c7e4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -48,7 +48,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match bus.metrics.metricRegistry.counter(s"queue.$SHARED_QUEUE.numDroppedEvents").getCount } - private def queueSize(bus: LiveListenerBus): Int = { + private def sharedQueueSize(bus: LiveListenerBus): Int = { bus.metrics.metricRegistry.getGauges().get(s"queue.$SHARED_QUEUE.size").getValue() .asInstanceOf[Int] } @@ -73,12 +73,11 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val conf = new SparkConf() val counter = new BasicJobCounter val bus = new LiveListenerBus(conf) - bus.addToSharedQueue(counter) // Metrics are initially empty. assert(bus.metrics.numEventsPosted.getCount === 0) assert(numDroppedEvents(bus) === 0) - assert(queueSize(bus) === 0) + assert(bus.queuedEvents.size === 0) assert(eventProcessingTimeCount(bus) === 0) // Post five events: @@ -87,7 +86,10 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Five messages should be marked as received and queued, but no messages should be posted to // listeners yet because the the listener bus hasn't been started. 
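// Until start() is called the posted events are buffered in bus.queuedEvents rather than in any
// listener queue; that is why the counter registered below, after the posts, still receives all
// five events once the bus starts, and why queuedEvents is expected to be null afterwards.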
assert(bus.metrics.numEventsPosted.getCount === 5) - assert(queueSize(bus) === 5) + assert(bus.queuedEvents.size === 5) + + // Add the counter to the bus after messages have been queued for later delivery. + bus.addToSharedQueue(counter) assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -95,9 +97,12 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(eventProcessingTimeCount(bus) === 5) + // After the bus is started, there should be no more queued events. + assert(bus.queuedEvents === null) + // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } @@ -188,18 +193,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Post a message to the listener bus and wait for processing to begin: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(numDroppedEvents(bus) === 0) // If we post an additional message then it should remain in the queue because the listener is // busy processing the first event: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 0) // The queue is now full, so any additional events posted to the listener will be dropped: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 1) // Allow the the remaining events to be processed so we can stop the listener bus: diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index a1d9085fa085..aa9c36c0aaac 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -158,6 +159,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } + test("TaskContext.stageAttemptNumber getter") { + sc = new SparkContext("local[1,2]", "test") + + // Check stageAttemptNumbers are 0 for initial stage + val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ => + Seq(TaskContext.get().stageAttemptNumber()).iterator + }.collect() + assert(stageAttemptNumbers.toSet === Set(0)) + + // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException + val stageAttemptNumbersWithFailedStage = + sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ => + val stageAttemptNumber = TaskContext.get().stageAttemptNumber() + if (stageAttemptNumber < 2) { + // Throw FetchFailedException to explicitly trigger 
stage resubmission. A normal exception + // will only trigger task resubmission in the same stage. + throw new FetchFailedException(null, 0, 0, 0, "Fake") + } + Seq(stageAttemptNumber).iterator + }.collect() + + assert(stageAttemptNumbersWithFailedStage.toSet === Set(2)) + } + test("accumulators are updated on exception failures") { // This means use 1 core and 4 max task failures sc = new SparkContext("local[1,4]", "test") @@ -190,7 +215,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.empty val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, @@ -213,7 +238,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 6003899bb7be..33f2ea1c94e7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -917,4 +917,108 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.initialize(new FakeSchedulerBackend) } } + + test("Completions in zombie tasksets update status of non-zombie taskset") { + val taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + val valueSer = SparkEnv.get.serializer.newInstance() + + def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = { + val indexInTsm = tsm.partitionToIndex(partition) + val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result) + } + + // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt, + // two times, so we have three active task sets for one stage. (For this to really happen, + // you'd need the previous stage to also get restarted, and then succeed, in between each + // attempt, but that happens outside what we're mocking here.) + val zombieAttempts = (0 until 2).map { stageAttempt => + val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt) + taskScheduler.submitTasks(attempt) + val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get + val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + taskScheduler.resourceOffers(offers) + assert(tsm.runningTasks === 10) + // fail attempt + tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + // the attempt is a zombie, but the tasks are still running (this could be true even if + // we actively killed those tasks, as killing is best-effort) + assert(tsm.isZombie) + assert(tsm.runningTasks === 9) + tsm + } + + // we've now got 2 zombie attempts, each with 9 tasks still active. 
Submit the 3rd attempt for + // the stage, but this time with insufficient resources so not all tasks are active. + + val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2) + taskScheduler.submitTasks(finalAttempt) + val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get + val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task => + finalAttempt.tasks(task.index).partitionId + }.toSet + assert(finalTsm.runningTasks === 5) + assert(!finalTsm.isZombie) + + // We simulate late completions from our zombie tasksets, corresponding to all the pending + // partitions in our final attempt. This means we're only waiting on the tasks we've already + // launched. + val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions) + finalAttemptPendingPartitions.foreach { partition => + completeTaskSuccessfully(zombieAttempts(0), partition) + } + + // If there is another resource offer, we shouldn't run anything. Though our final attempt + // used to have pending tasks, now those tasks have been completed by zombie attempts. The + // remaining tasks to compute are already active in the non-zombie attempt. + assert( + taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty) + + val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted + + // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be + // marked as zombie. + // for each of the remaining tasks, find the tasksets with an active copy of the task, and + // finish the task. + remainingTasks.foreach { partition => + val tsm = if (partition == 0) { + // we failed this task on both zombie attempts, this one is only present in the latest + // taskset + finalTsm + } else { + // should be active in every taskset. We choose a zombie taskset just to make sure that + // we transition the active taskset correctly even if the final completion comes + // from a zombie. + zombieAttempts(partition % 2) + } + completeTaskSuccessfully(tsm, partition) + } + + assert(finalTsm.isZombie) + + // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet + verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject()) + + // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything + // else succeeds, to make sure we get the right updates to the blacklist in all cases. + (zombieAttempts ++ Seq(finalTsm)).foreach { tsm => + val stageAttempt = tsm.taskSet.stageAttemptId + tsm.runningTasksSet.foreach { index => + if (stageAttempt == 1) { + tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost) + } else { + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result) + } + } + + // we update the blacklist for the stage attempts with all successful tasks. Even though + // some tasksets had failures, we still consider them all successful from a blacklisting + // perspective, as the failures weren't from a problem w/ the tasks themselves. 
+ verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index 18981d5be2f9..6e2709dbe1e8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -16,18 +16,32 @@ */ package org.apache.spark.scheduler +import org.mockito.Matchers.isA +import org.mockito.Mockito.{never, verify} +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config -import org.apache.spark.util.{ManualClock, SystemClock} +import org.apache.spark.util.ManualClock + +class TaskSetBlacklistSuite extends SparkFunSuite with BeforeAndAfterEach with MockitoSugar { -class TaskSetBlacklistSuite extends SparkFunSuite { + private var listenerBusMock: LiveListenerBus = _ + + override def beforeEach(): Unit = { + listenerBusMock = mock[LiveListenerBus] + super.beforeEach() + } test("Blacklisting tasks, executors, and nodes") { val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") val clock = new ManualClock + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, clock = clock) clock.setTime(0) // We will mark task 0 & 1 failed on both executor 1 & 2. // We should blacklist all executors on that host, for all tasks for the stage. Note the API @@ -46,27 +60,53 @@ class TaskSetBlacklistSuite extends SparkFunSuite { val shouldBeBlacklisted = (executor == "exec1" && index == 0) assert(taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === shouldBeBlacklisted) } + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerExecutorBlacklistedForStage])) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Mark task 1 failed on exec1 -- this pushes the executor into the blacklist taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, "exec1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) + // Mark one task as failed on exec2 -- not enough for any further blacklisting yet. taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) + // Mark another task as failed on exec2 -- now we blacklist exec2, which also leads to // blacklisting the entire node. 
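// The call below gives exec2 its second failed task in this stage, reaching the default
// per-executor stage limit (config.MAX_FAILURES_PER_EXEC_STAGE); with exec1 already blacklisted,
// hostA then reaches the per-node limit (config.MAX_FAILED_EXEC_PER_NODE_STAGE), so both the
// executor-level and the node-level "...BlacklistedForStage" events are expected on the bus.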
taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec2", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, "exec2", 2, 0, attemptId)) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock).post( + SparkListenerNodeBlacklistedForStage(0, "hostA", 2, 0, attemptId)) + // Make sure the blacklist has the correct per-task && per-executor responses, over a wider // range of inputs. for { @@ -81,6 +121,10 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // intentional, it keeps it fast and is sufficient for usage in the scheduler. taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === (badExec && badIndex)) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet(executor) === badExec) + if (badExec) { + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, executor, 2, 0, attemptId)) + } } } assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) @@ -110,7 +154,14 @@ class TaskSetBlacklistSuite extends SparkFunSuite { .set(config.MAX_TASK_ATTEMPTS_PER_NODE, 3) .set(config.MAX_FAILURES_PER_EXEC_STAGE, 2) .set(config.MAX_FAILED_EXEC_PER_NODE_STAGE, 3) - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + val clock = new ManualClock + + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) + + var time = 0 + clock.setTime(time) // Fail a task twice on hostA, exec:1 taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 0, failureReason = "testing") @@ -118,37 +169,75 @@ class TaskSetBlacklistSuite extends SparkFunSuite { "hostA", exec = "1", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTask("1", 0)) assert(!taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock, never()).post( + SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) // Fail the same task once more on hostA, exec:2 + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock, never()).post( + SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) // Fail another task on hostA, exec:1. 
Now that executor has failures on two different tasks, // so its blacklisted + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Fail a third task on hostA, exec:2, so that exec is blacklisted for the whole task set + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "2", index = 2, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Fail a fourth & fifth task on hostA, exec:3. Now we've got three executors that are // blacklisted for the taskset, so blacklist the whole node. + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "3", index = 3, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "3", index = 4, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("3")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "3", 2, 0, attemptId)) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 3, 0, attemptId)) } test("only blacklist nodes for the task set when all the blacklisted executors are all on " + @@ -157,22 +246,42 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // lead to any node blacklisting val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + val clock = new ManualClock + + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) + var time = 0 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 0, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostB", exec = "2", index = 0, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostB", exec = "2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) 
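// Both executors are now blacklisted for the task set, but they live on different hosts, so
// neither hostA nor hostB accumulates enough blacklisted executors to be blacklisted itself.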
assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostB")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 2ce81ae27daf..ca6a7e5db3b1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -683,7 +683,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val conf = new SparkConf().set("spark.speculation", "true") sc = new SparkContext("local", "test", conf) - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) sched.initialize(new FakeSchedulerBackend() { override def killTask( taskId: Long, @@ -709,6 +709,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } } + sched.dagScheduler.stop() sched.setDAGScheduler(dagScheduler) val singleTask = new ShuffleMapTask(0, 0, null, new Partition { @@ -754,7 +755,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc.conf.set("spark.speculation", "true") var killTaskCalled = false - val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) sched.initialize(new FakeSchedulerBackend() { override def killTask( @@ -789,6 +790,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } } + sched.dagScheduler.stop() sched.setDAGScheduler(dagScheduler) val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, @@ -1183,6 +1185,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val mockDAGScheduler = mock(classOf[DAGScheduler]) + sched.dagScheduler.stop() sched.dagScheduler = mockDAGScheduler val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index dba1172d5fdb..2d8a83c6fabe 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -108,7 +108,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong) } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator } // Create a mocked shuffle handle to pass into HashShuffleReader. 
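The TaskSetBlacklistSuite changes above all follow the same pattern: failures are driven through a ManualClock so the timestamp carried by the posted SparkListenerExecutorBlacklistedForStage / SparkListenerNodeBlacklistedForStage events is deterministic, and the mocked LiveListenerBus is used to verify that an event is posted exactly when the corresponding isExecutorBlacklistedForTaskSet / isNodeBlacklistedForTaskSet check flips to true. The sketch below is illustrative only; it assumes it runs alongside the suite in the org.apache.spark.scheduler test package (TaskSetBlacklist is package-private there), a Mockito mock named listenerBusMock, and the default stage-level limit under which two failed tasks blacklist an executor for the task set, as the assertions above rely on.

  import org.mockito.Mockito.{mock, never, verify}

  import org.apache.spark.SparkConf
  import org.apache.spark.internal.config
  import org.apache.spark.util.ManualClock

  val listenerBusMock = mock(classOf[LiveListenerBus])
  val clock = new ManualClock
  val conf = new SparkConf().setAppName("test").setMaster("local")
    .set(config.BLACKLIST_ENABLED.key, "true")
  val taskSetBlacklist = new TaskSetBlacklist(
    listenerBusMock, conf, stageId = 0, stageAttemptId = 0, clock = clock)

  clock.setTime(10)
  // One failed task is below the per-executor threshold, so nothing should be posted yet.
  taskSetBlacklist.updateBlacklistForFailedTask(
    "hostA", exec = "exec1", index = 0, failureReason = "testing")
  assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1"))
  verify(listenerBusMock, never()).post(
    SparkListenerExecutorBlacklistedForStage(10, "exec1", 2, 0, 0))

  // A second failed task on the same executor crosses the threshold: the predicate flips
  // and exactly one event carrying the manual clock's current time is expected on the bus.
  taskSetBlacklist.updateBlacklistForFailedTask(
    "hostA", exec = "exec1", index = 1, failureReason = "testing")
  assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1"))
  verify(listenerBusMock).post(
    SparkListenerExecutorBlacklistedForStage(10, "exec1", 2, 0, 0))

Because these listener events are plain case classes built from the clock's current time, Mockito's equality-based verify can match the exact expected event, which is why the suite advances the ManualClock before each failure it wants to tell apart.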
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index d21ce73f4021..4ce379b76b55 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.sort -import java.io.{File, FileInputStream, FileOutputStream} +import java.io.{DataInputStream, File, FileInputStream, FileOutputStream} import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS @@ -64,6 +64,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } test("commit shuffle files multiple times") { + val shuffleId = 1 + val mapId = 2 + val idxName = s"shuffle_${shuffleId}_${mapId}_0.index" val resolver = new IndexShuffleBlockResolver(conf, blockManager) val lengths = Array[Long](10, 0, 20) val dataTmp = File.createTempFile("shuffle", null, tempDir) @@ -73,9 +76,13 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) - val dataFile = resolver.getDataFile(1, 2) + val indexFile = new File(tempDir.getAbsolutePath, idxName) + val dataFile = resolver.getDataFile(shuffleId, mapId) + + assert(indexFile.exists()) + assert(indexFile.length() === (lengths.length + 1) * 8) assert(dataFile.exists()) assert(dataFile.length() === 30) assert(!dataTmp.exists()) @@ -89,7 +96,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out2.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2) + + assert(indexFile.length() === (lengths.length + 1) * 8) assert(lengths2.toSeq === lengths.toSeq) assert(dataFile.exists()) assert(dataFile.length() === 30) @@ -97,18 +106,27 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa // The dataFile should be the previous one val firstByte = new Array[Byte](1) - val in = new FileInputStream(dataFile) + val dataIn = new FileInputStream(dataFile) Utils.tryWithSafeFinally { - in.read(firstByte) + dataIn.read(firstByte) } { - in.close() + dataIn.close() } assert(firstByte(0) === 0) + // The index file should not change + val indexIn = new DataInputStream(new FileInputStream(indexFile)) + Utils.tryWithSafeFinally { + indexIn.readLong() // the first offset is always 0 + assert(indexIn.readLong() === 10, "The index file should not change") + } { + indexIn.close() + } + // remove data file dataFile.delete() - val lengths3 = Array[Long](10, 10, 15) + val lengths3 = Array[Long](7, 10, 15, 3) val dataTmp3 = File.createTempFile("shuffle", null, tempDir) val out3 = new FileOutputStream(dataTmp3) Utils.tryWithSafeFinally { @@ -117,20 +135,29 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out3.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3) + assert(indexFile.length() === (lengths3.length + 1) * 8) assert(lengths3.toSeq != lengths.toSeq) assert(dataFile.exists()) assert(dataFile.length() === 35) - assert(!dataTmp2.exists()) + assert(!dataTmp3.exists()) - // The dataFile should be the previous one 
- val firstByte2 = new Array[Byte](1) - val in2 = new FileInputStream(dataFile) + // The dataFile should be the new one, since we deleted the dataFile from the first attempt + val dataIn2 = new FileInputStream(dataFile) + Utils.tryWithSafeFinally { + dataIn2.read(firstByte) + } { + dataIn2.close() + } + assert(firstByte(0) === 2) + + // The index file should be updated, since we deleted the dataFile from the first attempt + val indexIn2 = new DataInputStream(new FileInputStream(indexFile)) Utils.tryWithSafeFinally { - in2.read(firstByte2) + indexIn2.readLong() // the first offset is always 0 + assert(indexIn2.readLong() === 7, "The index file should be updated") } { - in2.close() + indexIn2.close() } - assert(firstByte2(0) === 2) } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 55cebe7c8b6a..f29dac965c80 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -85,6 +85,14 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) + // We support serialized shuffle if we do not need to do map-side aggregation + assert(canUseSerializedShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = false + ))) } test("unsupported shuffle dependencies for serialized shuffle") { @@ -111,14 +119,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) - // We do not support shuffles that perform aggregation - assert(!canUseSerializedShuffle(shuffleDep( - partitioner = new HashPartitioner(2), - serializer = kryo, - keyOrdering = None, - aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), - mapSideCombine = false - ))) + // We do not support serialized shuffle if we need to do map-side aggregation assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 997c7de8dd02..1cd71955ad4d 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -195,7 +195,9 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val s1Tasks = createTasks(4, execIds) s1Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, + stages.head.attemptNumber, + task)) } assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size) @@ -211,55 +213,100 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { s1Tasks.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => - assert(wrapper.info.taskId === task.taskId) + assert(wrapper.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) assert(wrapper.stageAttemptId === stages.head.attemptId) - assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptId))) - - val runtime = Array[AnyRef](stages.head.stageId: JInteger, stages.head.attemptId: JInteger, - -1L: JLong) - 
assert(Arrays.equals(wrapper.runtime, runtime)) - - assert(wrapper.info.index === task.index) - assert(wrapper.info.attempt === task.attemptNumber) - assert(wrapper.info.launchTime === new Date(task.launchTime)) - assert(wrapper.info.executorId === task.executorId) - assert(wrapper.info.host === task.host) - assert(wrapper.info.status === task.status) - assert(wrapper.info.taskLocality === task.taskLocality.toString()) - assert(wrapper.info.speculative === task.speculative) + assert(wrapper.index === task.index) + assert(wrapper.attempt === task.attemptNumber) + assert(wrapper.launchTime === task.launchTime) + assert(wrapper.executorId === task.executorId) + assert(wrapper.host === task.host) + assert(wrapper.status === task.status) + assert(wrapper.taskLocality === task.taskLocality.toString()) + assert(wrapper.speculative === task.speculative) } } - // Send executor metrics update. Only update one metric to avoid a lot of boilerplate code. - s1Tasks.foreach { task => - val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), - Some(1L), None, true, false, None) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( - task.executorId, - Seq((task.taskId, stages.head.stageId, stages.head.attemptId, Seq(accum))))) + // Send two executor metrics updates. Only update one metric to avoid a lot of boilerplate code. + // The tasks are distributed among the two executors, so the executor-level metrics should + // hold half of the cumulative value of the metric being updated. + Seq(1L, 2L).foreach { value => + s1Tasks.foreach { task => + val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), + Some(value), None, true, false, None) + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( + task.executorId, + Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.memoryBytesSpilled === s1Tasks.size * value) + } + + val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)).last(key(stages.head)).asScala.toSeq + assert(execs.size > 0) + execs.foreach { exec => + assert(exec.info.memoryBytesSpilled === s1Tasks.size * value / 2) + } } - check[StageDataWrapper](key(stages.head)) { stage => - assert(stage.info.memoryBytesSpilled === s1Tasks.size) + // Blacklisting executor for stage + time += 1 + listener.onExecutorBlacklistedForStage(SparkListenerExecutorBlacklistedForStage( + time = time, + executorId = execIds.head, + taskFailures = 2, + stageId = stages.head.stageId, + stageAttemptId = stages.head.attemptId)) + + val executorStageSummaryWrappers = + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)) + .last(key(stages.head)) + .asScala.toSeq + + assert(executorStageSummaryWrappers.nonEmpty) + executorStageSummaryWrappers.foreach { exec => + // only the first executor is expected to be blacklisted + val expectedBlacklistedFlag = exec.executorId == execIds.head + assert(exec.info.isBlacklistedForStage === expectedBlacklistedFlag) } - val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") - .first(key(stages.head)).last(key(stages.head)).asScala.toSeq - assert(execs.size > 0) - execs.foreach { exec => - assert(exec.info.memoryBytesSpilled === s1Tasks.size / 2) + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set(stages.head.stageId)) + } + + // Blacklisting node
for stage + time += 1 + listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( + time = time, + hostId = "2.example.com", // this is where the second executor is hosted + executorFailures = 1, + stageId = stages.head.stageId, + stageAttemptId = stages.head.attemptId)) + + val executorStageSummaryWrappersForNode = + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)) + .last(key(stages.head)) + .asScala.toSeq + + assert(executorStageSummaryWrappersForNode.nonEmpty) + executorStageSummaryWrappersForNode.foreach { exec => + // both executors are expected to be blacklisted + assert(exec.info.isBlacklistedForStage === true) } // Fail one of the tasks, re-start it. time += 1 s1Tasks.head.markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", TaskResultLost, s1Tasks.head, null)) time += 1 val reattempt = newAttempt(s1Tasks.head, nextTaskId()) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, reattempt)) assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size + 1) @@ -275,13 +322,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](s1Tasks.head.taskId) { task => - assert(task.info.status === s1Tasks.head.status) - assert(task.info.errorMessage == Some(TaskResultLost.toErrorString)) + assert(task.status === s1Tasks.head.status) + assert(task.errorMessage == Some(TaskResultLost.toErrorString)) } check[TaskDataWrapper](reattempt.taskId) { task => - assert(task.info.index === s1Tasks.head.index) - assert(task.info.attempt === reattempt.attemptNumber) + assert(task.index === s1Tasks.head.index) + assert(task.attempt === reattempt.attemptNumber) } // Kill one task, restart it. @@ -289,7 +336,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val killed = s1Tasks.drop(1).head killed.finishTime = time killed.failed = true - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", TaskKilled("killed"), killed, null)) check[JobDataWrapper](1) { job => @@ -303,21 +350,21 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](killed.taskId) { task => - assert(task.info.index === killed.index) - assert(task.info.errorMessage === Some("killed")) + assert(task.index === killed.index) + assert(task.errorMessage === Some("killed")) } // Start a new attempt and finish it with TaskCommitDenied, make sure it's handled like a kill.
time += 1 val denied = newAttempt(killed, nextTaskId()) val denyReason = TaskCommitDenied(1, 1, 1) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, denied)) time += 1 denied.finishTime = time denied.failed = true - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", denyReason, denied, null)) check[JobDataWrapper](1) { job => @@ -331,13 +378,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](denied.taskId) { task => - assert(task.info.index === killed.index) - assert(task.info.errorMessage === Some(denyReason.toErrorString)) + assert(task.index === killed.index) + assert(task.errorMessage === Some(denyReason.toErrorString)) } // Start a new attempt. val reattempt2 = newAttempt(denied, nextTaskId()) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, reattempt2)) // Succeed all tasks in stage 1. @@ -350,7 +397,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 pending.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", Success, task, s1Metrics)) } @@ -370,10 +417,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { pending.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => - assert(wrapper.info.errorMessage === None) - assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L) - assert(wrapper.info.taskMetrics.get.executorRunTime === 4L) - assert(wrapper.info.duration === Some(task.duration)) + assert(wrapper.errorMessage === None) + assert(wrapper.executorCpuTime === 2L) + assert(wrapper.executorRunTime === 4L) + assert(wrapper.duration === task.duration) } } @@ -396,6 +443,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.numCompleteTasks === pending.size) } + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set()) + } + // Submit stage 2. time += 1 stages.last.submissionTime = Some(time) @@ -410,17 +461,32 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.submissionTime === Some(new Date(stages.last.submissionTime.get))) } + // Blacklisting node for stage + time += 1 + listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( + time = time, + hostId = "1.example.com", + executorFailures = 1, + stageId = stages.last.stageId, + stageAttemptId = stages.last.attemptId)) + + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set(stages.last.stageId)) + } + // Start and fail all tasks of stage 2. 
time += 1 val s2Tasks = createTasks(4, execIds) s2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, stages.last.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, + stages.last.attemptNumber, + task)) } time += 1 s2Tasks.foreach { task => task.markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptNumber, "taskType", TaskResultLost, task, null)) } @@ -455,7 +521,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // - Re-submit stage 2, all tasks, and succeed them and the stage. val oldS2 = stages.last - val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptId + 1, oldS2.name, oldS2.numTasks, + val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptNumber + 1, oldS2.name, oldS2.numTasks, oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics) time += 1 @@ -466,14 +532,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val newS2Tasks = createTasks(4, execIds) newS2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptNumber, task)) } time += 1 newS2Tasks.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptId, "taskType", Success, - task, null)) + listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptNumber, "taskType", + Success, task, null)) } time += 1 @@ -522,14 +588,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val j2s2Tasks = createTasks(4, execIds) j2s2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, j2Stages.last.attemptId, + listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, + j2Stages.last.attemptNumber, task)) } time += 1 j2s2Tasks.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptNumber, "taskType", Success, task, null)) } @@ -888,6 +955,27 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(store.count(classOf[StageDataWrapper]) === 3) assert(store.count(classOf[RDDOperationGraphWrapper]) === 3) + val dropped = stages.drop(1).head + + // Cache some quantiles by calling AppStatusStore.taskSummary(). For quantiles to be + // calculated, we need at least one finished task. The code in AppStatusStore uses + // `executorRunTime` to detect valid tasks, so that metric needs to be updated in the + // task end event. 
+ time += 1 + val task = createTasks(1, Array("1")).head + listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) + + time += 1 + task.markFinished(TaskState.FINISHED, time) + val metrics = TaskMetrics.empty + metrics.setExecutorRunTime(42L) + listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, + "taskType", Success, task, metrics)) + + new AppStatusStore(store) + .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) + assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 3) + stages.drop(1).foreach { s => time += 1 s.completionTime = Some(time) @@ -899,6 +987,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { intercept[NoSuchElementException] { store.read(classOf[StageDataWrapper], Array(2, 0)) } + assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 0) val attempt2 = new StageInfo(3, 1, "stage3", 4, Nil, Nil, "details3") time += 1 @@ -919,13 +1008,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val tasks = createTasks(2, Array("1")) tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) // Start a 3rd task. The finished tasks should be deleted. createTasks(1, Array("1")).foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) intercept[NoSuchElementException] { @@ -934,7 +1023,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start a 4th task. The first task should be deleted, even if it's still running. createTasks(1, Array("1")).foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) intercept[NoSuchElementException] { @@ -942,6 +1031,165 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("eviction should respect job completion time") { + val testConf = conf.clone().set(MAX_RETAINED_JOBS, 2) + val listener = new AppStatusListener(store, testConf, true) + + // Start job 1 and job 2 + time += 1 + listener.onJobStart(SparkListenerJobStart(1, time, Nil, null)) + time += 1 + listener.onJobStart(SparkListenerJobStart(2, time, Nil, null)) + + // Stop job 2 before job 1 + time += 1 + listener.onJobEnd(SparkListenerJobEnd(2, time, JobSucceeded)) + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + // Start job 3 and job 2 should be evicted. 
+ time += 1 + listener.onJobStart(SparkListenerJobStart(3, time, Nil, null)) + assert(store.count(classOf[JobDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[JobDataWrapper], 2) + } + } + + test("eviction should respect stage completion time") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + + // Start stage 1 and stage 2 + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + time += 1 + stage2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties())) + + // Stop stage 2 before stage 1 + time += 1 + stage2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage2)) + time += 1 + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Start stage 3 and stage 2 should be evicted. + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + assert(store.count(classOf[StageDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(2, 0)) + } + } + + test("skipped stages should be evicted before completed stages") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + + // Start job 1 + time += 1 + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage1, stage2), null)) + + // Start and stop stage 1 + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + + time += 1 + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Stop job 1 and stage 2 will become SKIPPED + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + // Submit stage 3 and verify stage 2 is evicted + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + time += 1 + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + + assert(store.count(classOf[StageDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(2, 0)) + } + } + + test("eviction should respect task completion time") { + val testConf = conf.clone().set(MAX_RETAINED_TASKS_PER_STAGE, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + + // Start task 1 and task 2 + val tasks = createTasks(3, Array("1")) + tasks.take(2).foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, task)) + } + + // Stop task 2 before task 1 + time += 1 + tasks(1).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(1), null)) + time
+= 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + + // Start task 3 and task 2 should be evicted. + listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, tasks(2))) + assert(store.count(classOf[TaskDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[TaskDataWrapper], tasks(1).id) + } + } + + test("lastStageAttempt should fail when the stage doesn't exist") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 1) + val listener = new AppStatusListener(store, testConf, true) + val appStore = new AppStatusStore(store) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Make stage 3 complete before stage 2 so that stage 3 will be evicted + time += 1 + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + stage3.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage3)) + + time += 1 + stage2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties())) + stage2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage2)) + + assert(appStore.asOption(appStore.lastStageAttempt(1)) === None) + assert(appStore.asOption(appStore.lastStageAttempt(2)).map(_.stageId) === Some(2)) + assert(appStore.asOption(appStore.lastStageAttempt(3)) === None) + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) @@ -960,7 +1208,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } - private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptNumber) private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { val value = store.read(classTag[T].runtimeClass, key).asInstanceOf[T] diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala new file mode 100644 index 000000000000..92f90f3d96dd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.status + +import org.apache.spark.SparkFunSuite +import org.apache.spark.status.api.v1.TaskMetricDistributions +import org.apache.spark.util.Distribution +import org.apache.spark.util.kvstore._ + +class AppStatusStoreSuite extends SparkFunSuite { + + private val uiQuantiles = Array(0.0, 0.25, 0.5, 0.75, 1.0) + private val stageId = 1 + private val attemptId = 1 + + test("quantile calculation: 1 task") { + compareQuantiles(1, uiQuantiles) + } + + test("quantile calculation: few tasks") { + compareQuantiles(4, uiQuantiles) + } + + test("quantile calculation: more tasks") { + compareQuantiles(100, uiQuantiles) + } + + test("quantile calculation: lots of tasks") { + compareQuantiles(4096, uiQuantiles) + } + + test("quantile calculation: custom quantiles") { + compareQuantiles(4096, Array(0.01, 0.33, 0.5, 0.42, 0.69, 0.99)) + } + + test("quantile cache") { + val store = new InMemoryStore() + (0 until 4096).foreach { i => store.write(newTaskData(i)) } + + val appStore = new AppStatusStore(store) + + appStore.taskSummary(stageId, attemptId, Array(0.13d)) + intercept[NoSuchElementException] { + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "13")) + } + + appStore.taskSummary(stageId, attemptId, Array(0.25d)) + val d1 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25")) + + // Add a new task to force the cached quantile to be evicted, and make sure it's updated. + store.write(newTaskData(4096)) + appStore.taskSummary(stageId, attemptId, Array(0.25d, 0.50d, 0.73d)) + + val d2 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25")) + assert(d1.taskCount != d2.taskCount) + + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "50")) + intercept[NoSuchElementException] { + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "73")) + } + + assert(store.count(classOf[CachedQuantile]) === 2) + } + + private def compareQuantiles(count: Int, quantiles: Array[Double]): Unit = { + val store = new InMemoryStore() + val values = (0 until count).map { i => + val task = newTaskData(i) + store.write(task) + i.toDouble + }.toArray + + val summary = new AppStatusStore(store).taskSummary(stageId, attemptId, quantiles).get + val dist = new Distribution(values, 0, values.length).getQuantiles(quantiles.sorted) + + dist.zip(summary.executorRunTime).foreach { case (expected, actual) => + assert(expected === actual) + } + } + + private def newTaskData(i: Int): TaskDataWrapper = { + new TaskDataWrapper( + i, i, i, i, i, i, i.toString, i.toString, i.toString, i.toString, false, Nil, None, + i, i, i, i, i, i, i, i, i, i, + i, i, i, i, i, i, i, i, i, i, + i, i, i, i, stageId, attemptId) + } + +} diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala new file mode 100644 index 000000000000..9e74e86ad54b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status + +import java.util.Date + +import org.apache.spark.SparkFunSuite +import org.apache.spark.status.api.v1.{TaskData, TaskMetrics} + +class AppStatusUtilsSuite extends SparkFunSuite { + + test("schedulerDelay") { + val runningTask = new TaskData( + taskId = 0, + index = 0, + attempt = 0, + launchTime = new Date(1L), + resultFetchStart = None, + duration = Some(100L), + executorId = "1", + host = "localhost", + status = "RUNNING", + taskLocality = "PROCESS_LOCAL", + speculative = false, + accumulatorUpdates = Nil, + errorMessage = None, + taskMetrics = Some(new TaskMetrics( + executorDeserializeTime = 0L, + executorDeserializeCpuTime = 0L, + executorRunTime = 0L, + executorCpuTime = 0L, + resultSize = 0L, + jvmGcTime = 0L, + resultSerializationTime = 0L, + memoryBytesSpilled = 0L, + diskBytesSpilled = 0L, + peakExecutionMemory = 0L, + inputMetrics = null, + outputMetrics = null, + shuffleReadMetrics = null, + shuffleWriteMetrics = null))) + assert(AppStatusUtils.schedulerDelay(runningTask) === 0L) + + val finishedTask = new TaskData( + taskId = 0, + index = 0, + attempt = 0, + launchTime = new Date(1L), + resultFetchStart = None, + duration = Some(100L), + executorId = "1", + host = "localhost", + status = "SUCCESS", + taskLocality = "PROCESS_LOCAL", + speculative = false, + accumulatorUpdates = Nil, + errorMessage = None, + taskMetrics = Some(new TaskMetrics( + executorDeserializeTime = 5L, + executorDeserializeCpuTime = 3L, + executorRunTime = 90L, + executorCpuTime = 10L, + resultSize = 100L, + jvmGcTime = 10L, + resultSerializationTime = 2L, + memoryBytesSpilled = 0L, + diskBytesSpilled = 0L, + peakExecutionMemory = 100L, + inputMetrics = null, + outputMetrics = null, + shuffleReadMetrics = null, + shuffleWriteMetrics = null))) + assert(AppStatusUtils.schedulerDelay(finishedTask) === 3L) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 917db766f7f1..9c0699bc981f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { try { TaskContext.setTaskContext( - new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null)) + new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null)) block } finally { TaskContext.unset() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 629eed49b04c..b19d8ebf72c6 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.storage import java.nio.ByteBuffer import scala.collection.JavaConverters._ -import scala.collection.mutable import 
scala.collection.mutable.ArrayBuffer import scala.concurrent.Future import scala.concurrent.duration._ @@ -44,8 +43,9 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} +import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} @@ -1325,9 +1325,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("SPARK-20640: Shuffle registration timeout and maxAttempts conf are working") { val tryAgainMsg = "test_spark_20640_try_again" + val timingoutExecutor = "timingoutExecutor" + val tryAgainExecutor = "tryAgainExecutor" + val succeedingExecutor = "succeedingExecutor" + // a server which delays response 50ms and must try twice for success. def newShuffleServer(port: Int): (TransportServer, Int) = { - val attempts = new mutable.HashMap[String, Int]() + val failure = new Exception(tryAgainMsg) + val success = ByteBuffer.wrap(new Array[Byte](0)) + + var secondExecutorFailedOnce = false + var thirdExecutorFailedOnce = false + val handler = new NoOpRpcHandler { override def receive( client: TransportClient, @@ -1335,15 +1344,26 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE callback: RpcResponseCallback): Unit = { val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message) msgObj match { - case exec: RegisterExecutor => - Thread.sleep(50) - val attempt = attempts.getOrElse(exec.execId, 0) + 1 - attempts(exec.execId) = attempt - if (attempt < 2) { - callback.onFailure(new Exception(tryAgainMsg)) - return - } - callback.onSuccess(ByteBuffer.wrap(new Array[Byte](0))) + + case exec: RegisterExecutor if exec.execId == timingoutExecutor => + () // No reply to generate client-side timeout + + case exec: RegisterExecutor + if exec.execId == tryAgainExecutor && !secondExecutorFailedOnce => + secondExecutorFailedOnce = true + callback.onFailure(failure) + + case exec: RegisterExecutor if exec.execId == tryAgainExecutor => + callback.onSuccess(success) + + case exec: RegisterExecutor + if exec.execId == succeedingExecutor && !thirdExecutorFailedOnce => + thirdExecutorFailedOnce = true + callback.onFailure(failure) + + case exec: RegisterExecutor if exec.execId == succeedingExecutor => + callback.onSuccess(success) + } } } @@ -1352,6 +1372,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val transCtx = new TransportContext(transConf, handler, true) (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) } + val candidatePort = RandomUtils.nextInt(1024, 65536) val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, newShuffleServer, conf, "ShuffleServer") @@ -1360,21 +1381,21 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.shuffle.service.port", shufflePort.toString) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40") 
conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") - var e = intercept[SparkException]{ - makeBlockManager(8000, "executor1") + var e = intercept[SparkException] { + makeBlockManager(8000, timingoutExecutor) }.getMessage assert(e.contains("TimeoutException")) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") - e = intercept[SparkException]{ - makeBlockManager(8000, "executor2") + e = intercept[SparkException] { + makeBlockManager(8000, tryAgainExecutor) }.getMessage assert(e.contains(tryAgainMsg)) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2") - makeBlockManager(8000, "executor3") + makeBlockManager(8000, succeedingExecutor) server.close() } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5bfe9905ff17..a2997dbd1b1a 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -65,12 +65,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // Create a mock managed buffer for testing - def createMockManagedBuffer(): ManagedBuffer = { + def createMockManagedBuffer(size: Int = 1): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) val in = mock(classOf[InputStream]) when(in.read(any())).thenReturn(1) when(in.read(any(), any(), any())).thenReturn(1) when(mockManagedBuffer.createInputStream()).thenReturn(in) + when(mockManagedBuffer.size()).thenReturn(size) mockManagedBuffer } @@ -99,7 +100,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) - ) + ).toIterator val iterator = new ShuffleBlockFetcherIterator( TaskContext.empty(), @@ -176,7 +177,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -244,7 +245,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -269,6 +270,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.size()).thenReturn(size) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + 
corruptBuffer + } + test("retry corrupt blocks") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -284,11 +294,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) @@ -301,7 +306,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) sem.release() @@ -310,7 +315,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -339,7 +344,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Future { // Return the first block, and then fail. listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) sem.release() } } @@ -352,6 +357,47 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + test("big blocks are not checked for corruption") { + val corruptBuffer = mockCorruptBuffer(10000L) + + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + val localBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 0, 0) -> corruptBuffer.size() + ) + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 1, 0) -> corruptBuffer.size() + ) + + val transfer = createMockTransfer( + Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer)) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (localBmId, localBlockLengths), + (remoteBmId, remoteBlockLengths) + ).toIterator + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 10000), + 2048, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + // Blocks should be returned without exceptions. 
+ assert(Set(iterator.next()._1, iterator.next()._1) === + Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + } + test("retry corrupt blocks (disabled)") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -368,11 +414,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - val transfer = mock(classOf[BlockTransferService]) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { @@ -383,16 +424,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, corruptBuffer) + ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer()) sem.release() } } }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -450,7 +491,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + def fetchShuffleBlock( + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. @@ -468,17 +510,52 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. 
assert(tempFileManager != null) } + + test("fail zero-size blocks") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer() + ) + + val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0))) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress.toIterator, + (_, in) => in, + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + + // All blocks fetched return zero length and should trigger a receive-side error: + val e = intercept[FetchFailedException] { iterator.next() } + assert(e.getMessage.contains("Received a zero-size buffer")) + } } diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index da198f946fd6..ca352387055f 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -51,27 +51,6 @@ class StorageSuite extends SparkFunSuite { assert(status.diskUsed === 60L) } - test("storage status update non-RDD blocks") { - val status = storageStatus1 - status.updateBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 50L, 100L)) - status.updateBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 100L, 20L)) - assert(status.blocks.size === 3) - assert(status.memUsed === 160L) - assert(status.memRemaining === 840L) - assert(status.diskUsed === 140L) - } - - test("storage status remove non-RDD blocks") { - val status = storageStatus1 - status.removeBlock(TestBlockId("foo")) - status.removeBlock(TestBlockId("faa")) - assert(status.blocks.size === 1) - assert(status.blocks.contains(TestBlockId("fee"))) - assert(status.memUsed === 10L) - assert(status.memRemaining === 990L) - assert(status.diskUsed === 20L) - } - // For testing add, update, remove, get, and contains etc. 
for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) @@ -95,85 +74,6 @@ class StorageSuite extends SparkFunSuite { assert(status.rddBlocks.contains(RDDBlockId(2, 2))) assert(status.rddBlocks.contains(RDDBlockId(2, 3))) assert(status.rddBlocks.contains(RDDBlockId(2, 4))) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) - assert(status.rddBlocksById(1).size === 1) - assert(status.rddBlocksById(1).contains(RDDBlockId(1, 1))) - assert(status.rddBlocksById(2).size === 3) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 2))) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 4))) - assert(status.memUsedByRdd(0) === 10L) - assert(status.memUsedByRdd(1) === 100L) - assert(status.memUsedByRdd(2) === 30L) - assert(status.diskUsedByRdd(0) === 20L) - assert(status.diskUsedByRdd(1) === 200L) - assert(status.diskUsedByRdd(2) === 80L) - assert(status.rddStorageLevel(0) === Some(memAndDisk)) - assert(status.rddStorageLevel(1) === Some(memAndDisk)) - assert(status.rddStorageLevel(2) === Some(memAndDisk)) - - // Verify default values for RDDs that don't exist - assert(status.rddBlocksById(10).isEmpty) - assert(status.memUsedByRdd(10) === 0L) - assert(status.diskUsedByRdd(10) === 0L) - assert(status.rddStorageLevel(10) === None) - } - - test("storage status update RDD blocks") { - val status = storageStatus2 - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 5000L, 0L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 0L, 1000L)) - assert(status.blocks.size === 7) - assert(status.rddBlocks.size === 5) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(1).size === 1) - assert(status.rddBlocksById(2).size === 3) - assert(status.memUsedByRdd(0) === 0L) - assert(status.memUsedByRdd(1) === 100L) - assert(status.memUsedByRdd(2) === 20L) - assert(status.diskUsedByRdd(0) === 0L) - assert(status.diskUsedByRdd(1) === 200L) - assert(status.diskUsedByRdd(2) === 1060L) - } - - test("storage status remove RDD blocks") { - val status = storageStatus2 - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(1, 1)) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 4)) - assert(status.blocks.size === 3) - assert(status.rddBlocks.size === 2) - assert(status.rddBlocks.contains(RDDBlockId(0, 0))) - assert(status.rddBlocks.contains(RDDBlockId(2, 3))) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) - assert(status.rddBlocksById(1).size === 0) - assert(status.rddBlocksById(2).size === 1) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) - assert(status.memUsedByRdd(0) === 10L) - assert(status.memUsedByRdd(1) === 0L) - assert(status.memUsedByRdd(2) === 10L) - assert(status.diskUsedByRdd(0) === 20L) - assert(status.diskUsedByRdd(1) === 0L) - assert(status.diskUsedByRdd(2) === 20L) - } - - test("storage status containsBlock") { - val status = storageStatus2 - // blocks that actually exist - assert(status.blocks.contains(TestBlockId("dan")) === status.containsBlock(TestBlockId("dan"))) - assert(status.blocks.contains(TestBlockId("man")) === status.containsBlock(TestBlockId("man"))) - assert(status.blocks.contains(RDDBlockId(0, 0)) === 
status.containsBlock(RDDBlockId(0, 0))) - assert(status.blocks.contains(RDDBlockId(1, 1)) === status.containsBlock(RDDBlockId(1, 1))) - assert(status.blocks.contains(RDDBlockId(2, 2)) === status.containsBlock(RDDBlockId(2, 2))) - assert(status.blocks.contains(RDDBlockId(2, 3)) === status.containsBlock(RDDBlockId(2, 3))) - assert(status.blocks.contains(RDDBlockId(2, 4)) === status.containsBlock(RDDBlockId(2, 4))) - // blocks that don't exist - assert(status.blocks.contains(TestBlockId("fan")) === status.containsBlock(TestBlockId("fan"))) - assert(status.blocks.contains(RDDBlockId(100, 0)) === status.containsBlock(RDDBlockId(100, 0))) } test("storage status getBlock") { @@ -191,40 +91,6 @@ class StorageSuite extends SparkFunSuite { assert(status.blocks.get(RDDBlockId(100, 0)) === status.getBlock(RDDBlockId(100, 0))) } - test("storage status num[Rdd]Blocks") { - val status = storageStatus2 - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - status.addBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 0L)) - status.addBlock(RDDBlockId(4, 4), BlockStatus(memAndDisk, 0L, 0L)) - status.addBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - status.updateBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 10L)) - status.updateBlock(RDDBlockId(4, 0), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(10, 10), BlockStatus(memAndDisk, 0L, 0L)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - assert(status.rddBlocksById(100).size === status.numRddBlocksById(100)) - status.removeBlock(RDDBlockId(4, 0)) - status.removeBlock(RDDBlockId(10, 10)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - // remove a block that doesn't exist - status.removeBlock(RDDBlockId(1000, 999)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - assert(status.rddBlocksById(1000).size === status.numRddBlocksById(1000)) - } test("storage status memUsed, diskUsed, externalBlockStoreUsed") { val status = storageStatus2 @@ -237,17 +103,6 @@ class StorageSuite extends SparkFunSuite { status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 4L, 5L)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) - status.removeBlock(TestBlockId("fire")) - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(2, 
2)) - status.removeBlock(RDDBlockId(2, 3)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) } // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations @@ -273,65 +128,6 @@ class StorageSuite extends SparkFunSuite { Seq(info0, info1) } - test("StorageUtils.updateRddInfo") { - val storageStatuses = stockStorageStatuses - val rddInfos = stockRDDInfos - StorageUtils.updateRddInfo(rddInfos, storageStatuses) - assert(rddInfos(0).storageLevel === memAndDisk) - assert(rddInfos(0).numCachedPartitions === 5) - assert(rddInfos(0).memSize === 5L) - assert(rddInfos(0).diskSize === 10L) - assert(rddInfos(0).externalBlockStoreSize === 0L) - assert(rddInfos(1).storageLevel === memAndDisk) - assert(rddInfos(1).numCachedPartitions === 3) - assert(rddInfos(1).memSize === 3L) - assert(rddInfos(1).diskSize === 6L) - assert(rddInfos(1).externalBlockStoreSize === 0L) - } - - test("StorageUtils.getRddBlockLocations") { - val storageStatuses = stockStorageStatuses - val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) - val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) - assert(blockLocations0.size === 5) - assert(blockLocations1.size === 3) - assert(blockLocations0.contains(RDDBlockId(0, 0))) - assert(blockLocations0.contains(RDDBlockId(0, 1))) - assert(blockLocations0.contains(RDDBlockId(0, 2))) - assert(blockLocations0.contains(RDDBlockId(0, 3))) - assert(blockLocations0.contains(RDDBlockId(0, 4))) - assert(blockLocations1.contains(RDDBlockId(1, 0))) - assert(blockLocations1.contains(RDDBlockId(1, 1))) - assert(blockLocations1.contains(RDDBlockId(1, 2))) - assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 4)) === Seq("cat:3")) - assert(blockLocations1(RDDBlockId(1, 0)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) - } - - test("StorageUtils.getRddBlockLocations with multiple locations") { - val storageStatuses = stockStorageStatuses - storageStatuses(0).addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L)) - storageStatuses(0).addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L)) - storageStatuses(2).addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) - val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) - val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) - assert(blockLocations0.size === 5) - assert(blockLocations1.size === 3) - assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1", "cat:3")) - assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 4)) === Seq("dog:1", "cat:3")) - assert(blockLocations1(RDDBlockId(1, 0)) === Seq("dog:1", "duck:2")) - assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) - } - private val offheap = StorageLevel.OFF_HEAP // For testing add, update, remove, get, and contains etc. 
for both RDD and non-RDD onheap // and offheap blocks @@ -373,21 +169,6 @@ class StorageSuite extends SparkFunSuite { status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(offheap, 4L, 0L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(offheap, 4L, 0L)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) - assert(status.onHeapMemUsed.get === actualOnHeapMemUsed) - assert(status.offHeapMemUsed.get === actualOffHeapMemUsed) - - status.removeBlock(TestBlockId("fire")) - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 3)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) } private def storageStatus4: StorageStatus = { diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 661d0d48d2f3..6044563f7dde 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -28,22 +28,82 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore -import org.apache.spark.ui.jobs.{StagePage, StagesTab} +import org.apache.spark.status.api.v1.{AccumulableInfo => UIAccumulableInfo, StageData, StageStatus} +import org.apache.spark.status.config._ +import org.apache.spark.ui.jobs.{ApiHelper, StagePage, StagesTab, TaskPagedTable} class StagePageSuite extends SparkFunSuite with LocalSparkContext { private val peakExecutionMemory = 10 + test("ApiHelper.COLUMN_TO_INDEX should match headers of the task table") { + val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + val statusStore = AppStatusStore.createLiveStore(conf) + try { + val stageData = new StageData( + status = StageStatus.ACTIVE, + stageId = 1, + attemptId = 1, + numTasks = 1, + numActiveTasks = 1, + numCompleteTasks = 1, + numFailedTasks = 1, + numKilledTasks = 1, + numCompletedIndices = 1, + + executorRunTime = 1L, + executorCpuTime = 1L, + submissionTime = None, + firstTaskLaunchedTime = None, + completionTime = None, + failureReason = None, + + inputBytes = 1L, + inputRecords = 1L, + outputBytes = 1L, + outputRecords = 1L, + shuffleReadBytes = 1L, + shuffleReadRecords = 1L, + shuffleWriteBytes = 1L, + shuffleWriteRecords = 1L, + memoryBytesSpilled = 1L, + diskBytesSpilled = 1L, + + name = "stage1", + description = Some("description"), + details = "detail", + schedulingPool = "pool1", + + rddIds = Seq(1), + accumulatorUpdates = Seq(new UIAccumulableInfo(0L, "acc", None, "value")), + tasks = None, + executorSummary = None, + killedTasksSummary = Map.empty + ) + val taskTable = new TaskPagedTable( + stageData, + basePath = "/a/b/c", + currentTime = 0, + pageSize = 10, + sortColumn = "Index", + desc = false, + store = statusStore + ) + val columnNames = (taskTable.headers \ "th" \ "a").map(_.child(1).text).toSet + assert(columnNames === ApiHelper.COLUMN_TO_INDEX.keySet) + } finally { + statusStore.close() + } + } + test("peak execution memory should displayed") { - val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) + val html = 
renderStagePage().toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" assert(html.contains(targetString)) } test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { - val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) + val html = renderStagePage().toString().toLowerCase(Locale.ROOT) // verify min/25/50/75/max show task value not cumulative values assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) } @@ -52,7 +112,8 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { * Render a stage page started with the given conf and return the HTML. * This also runs a dummy stage to populate the page with useful content. */ - private def renderStagePage(conf: SparkConf): Seq[Node] = { + private def renderStagePage(): Seq[Node] = { + val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) val statusStore = AppStatusStore.createLiveStore(conf) val listener = statusStore.listener.get diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 326546787ab6..e86cadfeebcf 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -131,7 +131,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val storageJson = getJson(ui, "storage/rdd") storageJson.children.length should be (1) - (storageJson \ "storageLevel").extract[String] should be (StorageLevels.DISK_ONLY.description) + (storageJson.children.head \ "storageLevel").extract[String] should be ( + StorageLevels.DISK_ONLY.description) val rddJson = getJson(ui, "storage/rdd/0") (rddJson \ "storageLevel").extract[String] should be (StorageLevels.DISK_ONLY.description) @@ -150,7 +151,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val updatedStorageJson = getJson(ui, "storage/rdd") updatedStorageJson.children.length should be (1) - (updatedStorageJson \ "storageLevel").extract[String] should be ( + (updatedStorageJson.children.head \ "storageLevel").extract[String] should be ( StorageLevels.MEMORY_ONLY.description) val updatedRddJson = getJson(ui, "storage/rdd/0") (updatedRddJson \ "storageLevel").extract[String] should be ( @@ -204,7 +205,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } val stageJson = getJson(sc.ui.get, "stages") stageJson.children.length should be (1) - (stageJson \ "status").extract[String] should be (StageStatus.FAILED.name()) + (stageJson.children.head \ "status").extract[String] should be (StageStatus.FAILED.name()) // Regression test for SPARK-2105 class NotSerializable @@ -325,11 +326,11 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") } val jobJson = getJson(sc.ui.get, "jobs") - (jobJson \ "numTasks").extract[Int]should be (2) - (jobJson \ "numCompletedTasks").extract[Int] should be (3) - (jobJson \ "numFailedTasks").extract[Int] should be (1) - (jobJson \ "numCompletedStages").extract[Int] should be (2) - (jobJson \ "numFailedStages").extract[Int] should be (1) + (jobJson \\ "numTasks").extract[Int]should be (2) + (jobJson \\ "numCompletedTasks").extract[Int] should be (3) + (jobJson \\ "numFailedTasks").extract[Int] should be (1) + (jobJson \\ "numCompletedStages").extract[Int] should be (2) + 
(jobJson \\ "numFailedStages").extract[Int] should be (1) val stageJson = getJson(sc.ui.get, "stages") for { @@ -656,7 +657,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.ui.get.webUrl + "/api/v1/applications")) val appListJsonAst = JsonMethods.parse(appListRawJson) appListJsonAst.children.length should be (1) - val attempts = (appListJsonAst \ "attempts").children + val attempts = (appListJsonAst.children.head \ "attempts").children attempts.size should be (1) (attempts(0) \ "completed").extract[Boolean] should be (false) parseDate(attempts(0) \ "startTime") should be (sc.startTime) @@ -706,6 +707,23 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } + test("stages page should show skipped stages") { + withSpark(newSparkContext()) { sc => + val rdd = sc.parallelize(0 to 100, 100).repartition(10).cache() + rdd.count() + rdd.count() + + eventually(timeout(5 seconds), interval(50 milliseconds)) { + goToUi(sc, "/stages") + find(id("skipped")).get.text should be("Skipped Stages (1)") + } + val stagesJson = getJson(sc.ui.get, "stages") + stagesJson.children.size should be (4) + val stagesStatus = stagesJson.children.map(_ \ "status") + stagesStatus.count(_ == JString(StageStatus.SKIPPED.name())) should be (1) + } + } + def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index a04644d57ed8..fe0a9a471a65 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import org.apache.spark._ +import org.apache.spark.serializer.JavaSerializer class AccumulatorV2Suite extends SparkFunSuite { @@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc3.isZero) assert(acc3.value === "") } + + test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") { + class MyData(val i: Int) extends Serializable + val param = new AccumulatorParam[MyData] { + override def zero(initialValue: MyData): MyData = new MyData(0) + override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i) + } + + val acc = new LegacyAccumulatorWrapper(new MyData(0), param) + acc.metadata = AccumulatorMetadata( + AccumulatorContext.newId(), + Some("test"), + countFailedValues = false) + AccumulatorContext.register(acc) + + val ser = new JavaSerializer(new SparkConf).newInstance() + ser.serialize(acc) + } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 4abbb8e7894f..74b72d940eee 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -317,7 +317,7 @@ class JsonProtocolSuite extends SparkFunSuite { test("SparkListenerJobStart backward compatibility") { // Prior to Spark 1.2.0, SparkListenerJobStart did not have a "Stage Infos" property. 
val stageIds = Seq[Int](1, 2, 3, 4) - val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500)) + val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) val dummyStageInfos = stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown")) val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) @@ -331,7 +331,7 @@ class JsonProtocolSuite extends SparkFunSuite { // Prior to Spark 1.3.0, SparkListenerJobStart did not have a "Submission Time" property. // Also, SparkListenerJobEnd did not have a "Completion Time" property. val stageIds = Seq[Int](1, 2, 3, 4) - val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40, x * 50)) + val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40L, x * 50L)) val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties) val oldStartEvent = JsonProtocol.jobStartToJson(jobStart) .removeField({ _._1 == "Submission Time"}) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index eaea6b030c15..3b4273184f1e 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -648,6 +648,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("fetch hcfs dir") { val tempDir = Utils.createTempDir() val sourceDir = new File(tempDir, "source-dir") + sourceDir.mkdir() val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 47173b89e91e..3e56db5ea116 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark._ import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordPointerAndKeyPrefix, UnsafeSortDataFormat} class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { @@ -105,9 +105,8 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { // the form [150000000, 150000001, 150000002, ...., 300000000, 0, 1, 2, ..., 149999999] // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi() val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i } - val buf = new LongArray(MemoryBlock.fromLongArray(ref)) - val tmp = new Array[Long](size/2) - val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(ref)) + val tmpBuf = new LongArray(new OnHeapMemoryBlock((size/2) * 8L)) new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort( buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala 
b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index d5956ea32096..ddf3740e76a7 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -27,7 +27,7 @@ import com.google.common.primitives.Ints import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom @@ -78,14 +78,14 @@ class RadixSortSuite extends SparkFunSuite with Logging { private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) - (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) + (ref.map(i => new JLong(i)), new LongArray(OnHeapMemoryBlock.fromArray(extended))) } private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0) - (new LongArray(MemoryBlock.fromLongArray(ref)), - new LongArray(MemoryBlock.fromLongArray(extended))) + (new LongArray(OnHeapMemoryBlock.fromArray(ref)), + new LongArray(OnHeapMemoryBlock.fromArray(extended))) } private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { @@ -110,7 +110,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { } private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( diff --git a/data/mllib/images/multi-channel/BGRA_alpha_60.png b/data/mllib/images/multi-channel/BGRA_alpha_60.png new file mode 100644 index 000000000000..913637cd2828 Binary files /dev/null and b/data/mllib/images/multi-channel/BGRA_alpha_60.png differ diff --git a/dev/.gitignore b/dev/.gitignore index 4a6027429e0d..c673922f36d2 100644 --- a/dev/.gitignore +++ b/dev/.gitignore @@ -1 +1,2 @@ pep8*.py +pycodestyle*.py diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 607234b4068d..9552d001a079 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -73,8 +73,10 @@ logs .*dependency-reduced-pom.xml known_translations json_expectation +app-20180109111548-0000 app-20161115172038-0000 app-20161116163331-0000 +application_1516285256255_0012 local-1422981759269 local-1422981780767 local-1425081759269 @@ -103,3 +105,4 @@ META-INF/* spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin +kafka-source-initial-offset-future-version.bin diff --git a/dev/appveyor-guide.md b/dev/appveyor-guide.md index d2e00b484727..a842f39b3049 100644 --- a/dev/appveyor-guide.md +++ b/dev/appveyor-guide.md @@ -1,6 +1,6 @@ # AppVeyor Guides -Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor.com). 
This page describes how to set up AppVeyor with Spark, how to run the build, check the status and stop the build via this tool. There is the documenation for AppVeyor [here](https://www.appveyor.com/docs). Please refer this for full details. +Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor.com). This page describes how to set up AppVeyor with Spark, how to run the build, check the status and stop the build via this tool. There is the documentation for AppVeyor [here](https://www.appveyor.com/docs). Please refer this for full details. ### Setting up AppVeyor @@ -45,7 +45,7 @@ Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor 2016-08-30 12 16 35 -- Since we will use Github here, click the "GITHUB" button and then click "Authorize Github" so that AppVeyor can access to the Github logs (e.g. commits). +- Since we will use Github here, click the "GITHUB" button and then click "Authorize Github" so that AppVeyor can access the Github logs (e.g. commits). 2016-09-04 11 10 22 @@ -87,7 +87,7 @@ Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor 2016-08-30 12 29 41 -- If the build is running, "CANCEL BUILD" buttom appears. Click this button top cancel the current build. +- If the build is running, "CANCEL BUILD" button appears. Click this button to cancel the current build. 2016-08-30 1 11 13 diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c71137468054..c00b00b84540 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -92,9 +92,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds @@ -164,8 +164,6 @@ if [[ "$1" == "package" ]]; then tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ --detach-sig spark-$SPARK_VERSION.tgz - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \ - spark-$SPARK_VERSION.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha512 rm -rf spark-$SPARK_VERSION @@ -215,9 +213,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output $R_DIST_NAME.asc \ --detach-sig $R_DIST_NAME - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 $R_DIST_NAME > \ - $R_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $R_DIST_NAME > \ $R_DIST_NAME.sha512 @@ -234,9 +229,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output $PYTHON_DIST_NAME.asc \ --detach-sig $PYTHON_DIST_NAME - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 $PYTHON_DIST_NAME > \ - $PYTHON_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $PYTHON_DIST_NAME > \ 
$PYTHON_DIST_NAME.sha512 @@ -247,9 +239,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \ - spark-$SPARK_VERSION-bin-$NAME.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 @@ -382,18 +371,11 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" - # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there + # this must have .asc and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ --detach-sig --armour $file; - if [ $(command -v md5) ]; then - # Available on OS X; -q to keep only hash - md5 -q $file > $file.md5 - else - # Available on Linux; cut to keep only hash - md5sum $file | cut -f1 -d' ' > $file.md5 - fi sha1sum $file | cut -f1 -d' ' > $file.sha1 done diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 730138195e5f..32f6cbb29f0b 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -185,6 +185,8 @@ def get_commits(tag): "graphx": "GraphX", "input/output": CORE_COMPONENT, "java api": "Java API", + "k8s": "Kubernetes", + "kubernetes": "Kubernetes", "mesos": "Mesos", "ml": "MLlib", "mllib": "MLlib", diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index a7fce2ede0ea..c3d1dd444b50 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -17,6 +17,7 @@ arpack_combined_all-0.1.jar arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar +automaton-1.11-8.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -47,7 +48,7 @@ commons-lang-2.6.jar commons-lang3-3.5.jar commons-logging-1.1.3.jar commons-math3-3.4.1.jar -commons-net-2.2.jar +commons-net-3.1.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar @@ -60,6 +61,7 @@ datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -91,8 +93,10 @@ jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar @@ -122,18 +126,22 @@ jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.2.11.jar -json4s-core_2.11-3.2.11.jar -json4s-jackson_2.11-3.2.11.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar +logging-interceptor-3.8.1.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar @@ -146,9 +154,11 @@ minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar 
objenesis-2.1.jar +okhttp-3.8.1.jar +okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.1-nohive.jar -orc-mapreduce-1.4.1-nohive.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar @@ -167,12 +177,12 @@ scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.5.jar -scalap-2.11.8.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar snappy-0.2.jar -snappy-java-1.1.2.6.jar +snappy-java-1.1.7.1.jar spire-macros_2.11-0.13.0.jar spire_2.11-0.13.0.jar stax-api-1.0-2.jar @@ -186,5 +196,6 @@ xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar +zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 94b2e98d85e7..290867035f91 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -17,6 +17,7 @@ arpack_combined_all-0.1.jar arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar +automaton-1.11-8.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -47,7 +48,7 @@ commons-lang-2.6.jar commons-lang3-3.5.jar commons-logging-1.1.3.jar commons-math3-3.4.1.jar -commons-net-2.2.jar +commons-net-3.1.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar @@ -60,6 +61,7 @@ datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -91,8 +93,10 @@ jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar @@ -122,19 +126,23 @@ jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.2.11.jar -json4s-core_2.11-3.2.11.jar -json4s-jackson_2.11-3.2.11.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar jsp-api-2.1.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar +logging-interceptor-3.8.1.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar @@ -147,9 +155,11 @@ minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar objenesis-2.1.jar +okhttp-3.8.1.jar +okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.1-nohive.jar -orc-mapreduce-1.4.1-nohive.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar @@ -168,12 +178,12 @@ scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.5.jar -scalap-2.11.8.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar snappy-0.2.jar -snappy-java-1.1.2.6.jar +snappy-java-1.1.7.1.jar spire-macros_2.11-0.13.0.jar spire_2.11-0.13.0.jar stax-api-1.0-2.jar @@ -187,5 +197,6 @@ xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar +zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 new 
file mode 100644 index 000000000000..97ad65a4096c --- /dev/null +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -0,0 +1,221 @@ +HikariCP-java7-2.4.12.jar +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +accessors-smart-1.2.jar +activation-1.1.1.jar +aircompressor-0.8.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.7.jar +aopalliance-1.0.jar +aopalliance-repackaged-2.4.0-b34.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +arrow-format-0.8.0.jar +arrow-memory-0.8.0.jar +arrow-vector-0.8.0.jar +automaton-1.11-8.jar +avro-1.7.7.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.58.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.13.2.jar +breeze_2.11-0.13.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.4.jar +chill_2.11-0.8.4.jar +commons-beanutils-1.9.3.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-3.0.8.jar +commons-compress-1.4.1.jar +commons-configuration2-2.1.1.jar +commons-crypto-1.0.0.jar +commons-daemon-1.0.13.jar +commons-dbcp-1.4.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.5.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-3.1.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.12.0.jar +curator-framework-2.12.0.jar +curator-recipes-2.12.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.12.1.1.jar +dnsjava-2.1.7.jar +ehcache-3.3.1.jar +eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar +geronimo-jcache_1.0_spec-1.0-alpha-1.jar +gson-2.2.4.jar +guava-14.0.1.jar +guice-4.0.jar +guice-servlet-4.0.jar +hadoop-annotations-3.1.0.jar +hadoop-auth-3.1.0.jar +hadoop-client-3.1.0.jar +hadoop-common-3.1.0.jar +hadoop-hdfs-client-3.1.0.jar +hadoop-mapreduce-client-common-3.1.0.jar +hadoop-mapreduce-client-core-3.1.0.jar +hadoop-mapreduce-client-jobclient-3.1.0.jar +hadoop-yarn-api-3.1.0.jar +hadoop-yarn-client-3.1.0.jar +hadoop-yarn-common-3.1.0.jar +hadoop-yarn-registry-3.1.0.jar +hadoop-yarn-server-common-3.1.0.jar +hadoop-yarn-server-web-proxy-3.1.0.jar +hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar +hppc-0.7.2.jar +htrace-core4-4.1.0-incubating.jar +httpclient-4.5.4.jar +httpcore-4.4.8.jar +ivy-2.4.0.jar +jackson-annotations-2.6.7.jar +jackson-core-2.6.7.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar +jackson-jaxrs-base-2.7.8.jar +jackson-jaxrs-json-provider-2.7.8.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar +jackson-module-paranamer-2.7.9.jar +jackson-module-scala_2.11-2.6.7.1.jar +janino-3.0.8.jar +java-xmlbuilder-1.1.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar +javax.inject-1.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar +javolution-5.5.1.jar +jaxb-api-2.2.11.jar +jcip-annotations-1.0-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar +jets3t-0.9.4.jar +jetty-webapp-9.3.20.v20170531.jar +jetty-xml-9.3.20.v20170531.jar +jline-2.12.1.jar +joda-time-2.9.3.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-smart-2.3.jar 
+json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar +jsp-api-2.1.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kerb-admin-1.0.1.jar +kerb-client-1.0.1.jar +kerb-common-1.0.1.jar +kerb-core-1.0.1.jar +kerb-crypto-1.0.1.jar +kerb-identity-1.0.1.jar +kerb-server-1.0.1.jar +kerb-simplekdc-1.0.1.jar +kerb-util-1.0.1.jar +kerby-asn1-1.0.1.jar +kerby-config-1.0.1.jar +kerby-pkix-1.0.1.jar +kerby-util-1.0.1.jar +kerby-xdr-1.0.1.jar +kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar +leveldbjni-all-1.8.jar +libfb303-0.9.3.jar +libthrift-0.9.3.jar +log4j-1.2.17.jar +logging-interceptor-3.8.1.jar +lz4-java-1.4.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar +mesos-1.4.0-shaded-protobuf.jar +metrics-core-3.1.5.jar +metrics-graphite-3.1.5.jar +metrics-json-3.1.5.jar +metrics-jvm-3.1.5.jar +minlog-1.3.0.jar +mssql-jdbc-6.2.1.jre7.jar +netty-3.9.9.Final.jar +netty-all-4.1.17.Final.jar +nimbus-jose-jwt-4.41.1.jar +objenesis-2.1.jar +okhttp-2.7.5.jar +okhttp-3.8.1.jar +okio-1.13.0.jar +opencsv-2.3.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar +oro-2.0.8.jar +osgi-resource-locator-1.0.1.jar +paranamer-2.8.jar +parquet-column-1.8.2.jar +parquet-common-1.8.2.jar +parquet-encoding-1.8.2.jar +parquet-format-2.3.1.jar +parquet-hadoop-1.8.2.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.8.2.jar +protobuf-java-2.5.0.jar +py4j-0.10.6.jar +pyrolite-4.13.jar +re2j-1.1.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.8.jar +scala-xml_2.11-1.0.5.jar +shapeless_2.11-2.3.2.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar +snappy-0.2.jar +snappy-java-1.1.7.1.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar +stax-api-1.0.1.jar +stax2-api-3.1.4.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +token-provider-1.0.1.jar +univocity-parsers-2.5.9.jar +validation-api-1.1.0.Final.jar +woodstox-core-5.0.3.jar +xbean-asm5-shaded-4.4.jar +xz-1.0.jar +zjsonpatch-0.3.0.jar +zookeeper-3.4.9.jar +zstd-jni-1.3.2-2.jar diff --git a/dev/lint-java b/dev/lint-java index c2e80538ef2a..1f0b0c8379ed 100755 --- a/dev/lint-java +++ b/dev/lint-java @@ -20,7 +20,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" -ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pmesos -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) +ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pmesos -Pkubernetes -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) if test ! -z "$ERRORS"; then echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" diff --git a/dev/lint-python b/dev/lint-python index df8df037a5f6..f738af9c4976 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -21,7 +21,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" # Exclude auto-generated configuration file. PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . 
-name "*.py" )" -PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" +PYCODESTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" SPHINXBUILD=${SPHINXBUILD:=sphinx-build} @@ -30,23 +30,22 @@ SPHINX_REPORT_PATH="$SPARK_ROOT_DIR/dev/sphinx-report.txt" cd "$SPARK_ROOT_DIR" # compileall: https://docs.python.org/2/library/compileall.html -python -B -m compileall -q -l $PATHS_TO_CHECK > "$PEP8_REPORT_PATH" +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYCODESTYLE_REPORT_PATH" compile_status="${PIPESTATUS[0]}" -# Get pep8 at runtime so that we don't rely on it being installed on the build server. -#+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 -#+ TODOs: -#+ - Download pep8 from PyPI. It's more "official". -PEP8_VERSION="1.7.0" -PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8-$PEP8_VERSION.py" -PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/$PEP8_VERSION/pep8.py" +# Get pycodestyle at runtime so that we don't rely on it being installed on the build server. +# See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 +# Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. +PYCODESTYLE_VERSION="2.3.1" +PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" +PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" -if [ ! -e "$PEP8_SCRIPT_PATH" ]; then - curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" +if [ ! -e "$PYCODESTYLE_SCRIPT_PATH" ]; then + curl --silent -o "$PYCODESTYLE_SCRIPT_PATH" "$PYCODESTYLE_SCRIPT_REMOTE_PATH" curl_status="$?" if [ "$curl_status" -ne 0 ]; then - echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." + echo "Failed to download pycodestyle.py from \"$PYCODESTYLE_SCRIPT_REMOTE_PATH\"." exit "$curl_status" fi fi @@ -61,26 +60,26 @@ export "PYLINT_HOME=$PYTHONPATH" export "PATH=$PYTHONPATH:$PATH" # There is no need to write this output to a file -#+ first, but we do so so that the check status can -#+ be output before the report, like with the -#+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" -pep8_status="${PIPESTATUS[0]}" +# first, but we do so so that the check status can +# be output before the report, like with the +# scalastyle and RAT checks. +python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PYCODESTYLE_REPORT_PATH" +pycodestyle_status="${PIPESTATUS[0]}" -if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then +if [ "$compile_status" -eq 0 -a "$pycodestyle_status" -eq 0 ]; then lint_status=0 else lint_status=1 fi if [ "$lint_status" -ne 0 ]; then - echo "PEP8 checks failed." - cat "$PEP8_REPORT_PATH" - rm "$PEP8_REPORT_PATH" + echo "pycodestyle checks failed." + cat "$PYCODESTYLE_REPORT_PATH" + rm "$PYCODESTYLE_REPORT_PATH" exit "$lint_status" else - echo "PEP8 checks passed." - rm "$PEP8_REPORT_PATH" + echo "pycodestyle checks passed." + rm "$PYCODESTYLE_REPORT_PATH" fi # Check that the documentation builds acceptably, skip check if sphinx is not installed. 
diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 7245163ea2a5..84233c64caa9 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -72,9 +72,17 @@ while (( "$#" )); do --help) exit_with_usage ;; - *) + --*) + echo "Error: $1 is not supported" + exit_with_usage + ;; + -*) break ;; + *) + echo "Error: $1 is not supported" + exit_with_usage + ;; esac shift done @@ -117,15 +125,21 @@ if [ ! "$(command -v "$MVN")" ] ; then exit -1; fi -VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null | grep -v "INFO" | tail -n 1) +VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null\ + | grep -v "INFO"\ + | grep -v "WARNING"\ + | tail -n 1) SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | tail -n 1) SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | tail -n 1) SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ # because we use "set -o pipefail" diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 57ca8400b6f3..5ea205fbed4a 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -30,6 +30,7 @@ import re import subprocess import sys +import traceback import urllib2 try: @@ -298,24 +299,37 @@ def choose_jira_assignee(issue, asf_jira): Prompt the user to choose who to assign the issue to in jira, given a list of candidates, including the original reporter and all commentors """ - reporter = issue.fields.reporter - commentors = map(lambda x: x.author, issue.fields.comment.comments) - candidates = set(commentors) - candidates.add(reporter) - candidates = list(candidates) - print("JIRA is unassigned, choose assignee") - for idx, author in enumerate(candidates): - annotations = ["Reporter"] if author == reporter else [] - if author in commentors: - annotations.append("Commentor") - print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) - assignee = raw_input("Enter number of user to assign to (blank to leave unassigned):") - if assignee == "": - return None - else: - assignee = candidates[int(assignee)] - asf_jira.assign_issue(issue.key, assignee.key) - return assignee + while True: + try: + reporter = issue.fields.reporter + commentors = map(lambda x: x.author, issue.fields.comment.comments) + candidates = set(commentors) + candidates.add(reporter) + candidates = list(candidates) + print("JIRA is unassigned, choose assignee") + for idx, author in enumerate(candidates): + if author.key == "apachespark": + continue + annotations = ["Reporter"] if author == reporter else [] + if author in commentors: + annotations.append("Commentor") + print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) + raw_assignee = raw_input( + "Enter number of user, or userid, to assign to (blank to leave unassigned):") + if raw_assignee == "": + return None + else: + try: + id = int(raw_assignee) + assignee = candidates[id] + except: + # assume it's a user id, and try to assign (might fail, we just prompt again) + assignee = asf_jira.user(raw_assignee) + asf_jira.assign_issue(issue.key, assignee.key) + return assignee + except: + traceback.print_exc() + print("Error assigning JIRA, try again (or leave blank and fix manually)") def 
resolve_jira_issues(title, merge_branches, comment): @@ -496,7 +510,7 @@ def main(): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) try: main() except: diff --git a/dev/mima b/dev/mima index 1e3ca9700bc0..cd2694ff4d3d 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/run-pip-tests b/dev/run-pip-tests index d51dde12a03c..1321c2be4c19 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -25,10 +25,10 @@ shopt -s nullglob FWDIR="$(cd "$(dirname "$0")"/..; pwd)" cd "$FWDIR" -echo "Constucting virtual env for testing" +echo "Constructing virtual env for testing" VIRTUALENV_BASE=$(mktemp -d) -# Clean up the virtual env enviroment used if we created one. +# Clean up the virtual env environment used if we created one. function delete_virtualenv() { echo "Cleaning up temporary directory - $VIRTUALENV_BASE" rm -rf "$VIRTUALENV_BASE" diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 03fd6ff0fba4..5bc03e41d1f2 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -20,7 +20,7 @@ # Wrapper script that runs the Spark tests then reports QA results # to github via its API. # Environment variables are populated by the code here: -#+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 +# https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 FWDIR="$( cd "$( dirname "$0" )/.." && pwd )" cd "$FWDIR" diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 914eb93622d5..3960a0de6253 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -181,8 +181,8 @@ def main(): short_commit_hash = ghprb_actual_commit[0:7] # format: http://linux.die.net/man/1/timeout - # must be less than the timeout configured on Jenkins (currently 300m) - tests_timeout = "250m" + # must be less than the timeout configured on Jenkins (currently 350m) + tests_timeout = "300m" # Array to capture all test names to run on the pull request. These tests are represented # by their file equivalents in the dev/tests/ directory. 
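Both dev/merge_spark_pr.py (above) and dev/run-tests.py (below) switch their doctest self-test guards from the bare exit() builtin to sys.exit(). A minimal sketch of that pattern, for context:

```python
import doctest
import sys

# Minimal sketch of the self-test guard used by the dev scripts: the patch
# prefers sys.exit() over the exit() builtin, which is injected by the site
# module and is not guaranteed to be available in every environment.
if __name__ == "__main__":
    failure_count, _ = doctest.testmod()
    if failure_count:
        sys.exit(-1)
```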
diff --git a/dev/run-tests.py b/dev/run-tests.py index 7e6f7ff06035..164c1e2200aa 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -576,9 +576,15 @@ def main(): for f in changed_files): # run_java_style_checks() pass - if not changed_files or any(f.endswith(".py") for f in changed_files): + if not changed_files or any(f.endswith("lint-python") + or f.endswith("tox.ini") + or f.endswith(".py") + for f in changed_files): run_python_style_checks() - if not changed_files or any(f.endswith(".R") for f in changed_files): + if not changed_files or any(f.endswith(".R") + or f.endswith("lint-r") + or f.endswith(".lintr") + for f in changed_files): run_sparkr_style_checks() # determine if docs were changed and if we're inside the amplab environment @@ -615,7 +621,7 @@ def _test(): import doctest failure_count = doctest.testmod()[0] if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/dev/scalastyle b/dev/scalastyle index 89ecc8abd6f8..b8053df05fa2 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -24,6 +24,7 @@ ERRORS=$(echo -e "q\n" \ -Pkinesis-asl \ -Pmesos \ -Pkafka-0-8 \ + -Pkubernetes \ -Pyarn \ -Pflume \ -Phive \ diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index f834563da9dd..dfea762db98c 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -25,10 +25,10 @@ @total_ordering class Module(object): """ - A module is the basic abstraction in our test runner script. Each module consists of a set of - source files, a set of test commands, and a set of dependencies on other modules. We use modules - to define a dependency graph that lets determine which tests to run based on which files have - changed. + A module is the basic abstraction in our test runner script. Each module consists of a set + of source files, a set of test commands, and a set of dependencies on other modules. We use + modules to define a dependency graph that let us determine which tests to run based on which + files have changed. """ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, @@ -400,6 +400,7 @@ def __hash__(self): "pyspark.sql.functions", "pyspark.sql.readwriter", "pyspark.sql.streaming", + "pyspark.sql.udf", "pyspark.sql.window", "pyspark.sql.tests", ] @@ -539,7 +540,7 @@ def __hash__(self): kubernetes = Module( name="kubernetes", dependencies=[], - source_file_regexes=["resource-managers/kubernetes/core"], + source_file_regexes=["resource-managers/kubernetes"], build_profile_flags=["-Pkubernetes"], sbt_test_goals=["kubernetes/test"] ) diff --git a/dev/sparktestsupport/toposort.py b/dev/sparktestsupport/toposort.py index 6c67b4504bc3..8b2688d20039 100644 --- a/dev/sparktestsupport/toposort.py +++ b/dev/sparktestsupport/toposort.py @@ -43,8 +43,8 @@ def toposort(data): """Dependencies are expressed as a dictionary whose keys are items and whose values are a set of dependent items. Output is a list of sets in topological order. The first set consists of items with no -dependences, each subsequent set consists of items that depend upon -items in the preceeding sets. +dependencies, each subsequent set consists of items that depend upon +items in the preceding sets. """ # Special case empty input. @@ -59,7 +59,7 @@ def toposort(data): v.discard(k) # Find all items that don't depend on anything. extra_items_in_deps = _reduce(set.union, data.values()) - set(data.keys()) - # Add empty dependences where needed. + # Add empty dependencies where needed. 
data.update({item: set() for item in extra_items_in_deps}) while True: ordered = set(item for item, dep in data.items() if len(dep) == 0) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 58b295d4f6e0..2fbd6b5e98f7 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,11 +29,12 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 hadoop-2.7 + hadoop-3.1 ) # We'll switch the version to a temp. one, publish POMs using that new version, then switch back to diff --git a/dev/tests/pr_merge_ability.sh b/dev/tests/pr_merge_ability.sh index d9a347fe24a8..25fdbccac4dd 100755 --- a/dev/tests/pr_merge_ability.sh +++ b/dev/tests/pr_merge_ability.sh @@ -23,9 +23,9 @@ # found at dev/run-tests-jenkins. # # Arg1: The Github Pull Request Actual Commit -#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# known as `ghprbActualCommit` in `run-tests-jenkins` # Arg2: The SHA1 hash -#+ known as `sha1` in `run-tests-jenkins` +# known as `sha1` in `run-tests-jenkins` # ghprbActualCommit="$1" diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh index 41c5d3ee8cb3..479d1851fe0b 100755 --- a/dev/tests/pr_public_classes.sh +++ b/dev/tests/pr_public_classes.sh @@ -23,7 +23,7 @@ # found at dev/run-tests-jenkins. # # Arg1: The Github Pull Request Actual Commit -#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# known as `ghprbActualCommit` in `run-tests-jenkins` ghprbActualCommit="$1" @@ -31,7 +31,7 @@ ghprbActualCommit="$1" # master commit and the tip of the pull request branch. # By diffing$ghprbActualCommit^...$ghprbActualCommit and filtering to examine the diffs of only -# non-test files, we can gets us changes introduced in the PR and not anything else added to master +# non-test files, we can get changes introduced in the PR and not anything else added to master # since the PR was branched. # Handle differences between GNU and BSD sed diff --git a/dev/tox.ini b/dev/tox.ini index eb8b1eb2c288..583c1eaaa966 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -[pep8] -ignore=E402,E731,E241,W503,E226 +[pycodestyle] +ignore=E402,E731,E241,W503,E226,E722,E741,E305 max-line-length=100 exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/* diff --git a/docs/README.md b/docs/README.md index 225bb1b2040d..dbea4d64c429 100644 --- a/docs/README.md +++ b/docs/README.md @@ -5,7 +5,7 @@ here with the Spark source code. You can also find documentation specific to rel Spark at http://spark.apache.org/documentation.html. Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the -documentation yourself. Why build it yourself? So that you have the docs that corresponds to +documentation yourself. Why build it yourself? So that you have the docs that correspond to whichever version of Spark you currently have checked out of revision control. 
## Prerequisites @@ -22,10 +22,13 @@ $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments # Following is needed only for generating API docs $ sudo pip install sphinx pypandoc mkdocs -$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'devtools::install_version("roxygen2", version = "5.0.1", repos="http://cran.stat.ucla.edu/")' ``` -(Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0) +Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0. + +Note: Other versions of roxygen2 might work in SparkR documentation generation but `RoxygenNote` field in `$SPARK_HOME/R/pkg/DESCRIPTION` is 5.0.1, which is updated if the version is mismatched. ## Generating the Documentation HTML @@ -62,12 +65,12 @@ $ PRODUCTION=1 jekyll build ## API Docs (Scaladoc, Javadoc, Sphinx, roxygen2, MkDocs) -You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `SPARK_HOME` directory. +You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `$SPARK_HOME` directory. Similarly, you can build just the PySpark docs by running `make html` from the -`SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as -public in `__init__.py`. The SparkR docs can be built by running `SPARK_HOME/R/create-docs.sh`, and -the SQL docs can be built by running `SPARK_HOME/sql/create-docs.sh` +`$SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as +public in `__init__.py`. The SparkR docs can be built by running `$SPARK_HOME/R/create-docs.sh`, and +the SQL docs can be built by running `$SPARK_HOME/sql/create-docs.sh` after [building Spark](https://github.com/apache/spark#building-spark) first. When you run `jekyll build` in the `docs` directory, it will also copy over the scaladoc and javadoc for the various diff --git a/docs/_config.yml b/docs/_config.yml index dcc211204d76..095fadb93fe5 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.3.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.3.0 +SPARK_VERSION: 2.4.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.4.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.8" MESOS_VERSION: 1.0.0 diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 6ea1d438f529..1e91f12518e0 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -48,7 +48,7 @@ def render(context) begin code = File.open(@file).read.encode("UTF-8") rescue => e - # We need to explicitly exit on execptions here because Jekyll will silently swallow + # We need to explicitly exit on exceptions here because Jekyll will silently swallow # them, leading to silent build failures (see https://github.com/jekyll/jekyll/issues/5104) puts(e) puts(e.backtrace) diff --git a/docs/building-spark.md b/docs/building-spark.md index c391255a9159..0236bb05849a 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -113,7 +113,7 @@ Note: Flume support is deprecated as of Spark 2.3.0. 
## Building submodules individually -It's possible to build Spark sub-modules using the `mvn -pl` option. +It's possible to build Spark submodules using the `mvn -pl` option. For instance, you can build the Spark Streaming module using: diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index 751a192da4ff..ac1c33698893 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -27,13 +27,13 @@ description: Introduction to cloud storage support in Apache Spark SPARK_VERSION All major cloud providers offer persistent data storage in *object stores*. These are not classic "POSIX" file systems. In order to store hundreds of petabytes of data without any single points of failure, -object stores replace the classic filesystem directory tree +object stores replace the classic file system directory tree with a simpler model of `object-name => data`. To enable remote access, operations on objects are usually offered as (slow) HTTP REST operations. Spark can read and write data in object stores through filesystem connectors implemented in Hadoop or provided by the infrastructure suppliers themselves. -These connectors make the object stores look *almost* like filesystems, with directories and files +These connectors make the object stores look *almost* like file systems, with directories and files and the classic operations on them such as list, delete and rename. @@ -180,10 +180,10 @@ under the path, not the number of *new* files, so it can become a slow operation The size of the window needs to be set to handle this. 1. Files only appear in an object store once they are completely written; there -is no need for a worklow of write-then-rename to ensure that files aren't picked up +is no need for a workflow of write-then-rename to ensure that files aren't picked up while they are still being written. Applications can write straight to the monitored directory. -1. Streams should only be checkpointed to an store implementing a fast and +1. Streams should only be checkpointed to a store implementing a fast and atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. ## Further Reading diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 658e67f99dd7..7277e2fb2731 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -52,8 +52,8 @@ The system currently supports three cluster managers: * [Apache Mesos](running-on-mesos.html) -- a general cluster manager that can also run Hadoop MapReduce and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. -* [Kubernetes](running-on-kubernetes.html) -- [Kubernetes](https://kubernetes.io/docs/concepts/overview/what-is-kubernetes/) -is an open-source platform that provides container-centric infrastructure. +* [Kubernetes](running-on-kubernetes.html) -- an open-source system for automating deployment, scaling, + and management of containerized applications. A third-party project (not supported by the Spark project) exists to add support for [Nomad](https://github.com/hashicorp/nomad-spark) as a cluster manager. diff --git a/docs/configuration.md b/docs/configuration.md index 1189aea2aa71..8a1aacef8576 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -58,6 +58,10 @@ The following format is accepted: 1t or 1tb (tebibytes = 1024 gibibytes) 1p or 1pb (pebibytes = 1024 tebibytes) +While numbers without units are generally interpreted as bytes, a few are interpreted as KiB or MiB. 
+See documentation of individual configuration properties. Specifying units is desirable where +possible. + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For @@ -75,7 +79,7 @@ Then, you can supply configuration values at runtime: {% endhighlight %} The Spark shell and [`spark-submit`](submitting-applications.html) -tool support two ways to load configurations dynamically. The first are command line options, +tool support two ways to load configurations dynamically. The first is command line options, such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf` flag, but uses special flags for properties that play a part in launching the Spark application. Running `./bin/spark-submit --help` will show the entire list of these options. @@ -136,9 +140,9 @@ of the most common options to set are: spark.driver.maxResultSize 1g - Limit of total size of serialized results of all partitions for each Spark action (e.g. collect). - Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size - is above this limit. + Limit of total size of serialized results of all partitions for each Spark action (e.g. + collect) in bytes. Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total + size is above this limit. Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory and memory overhead of objects in JVM). Setting a proper limit can protect the driver from out-of-memory errors. @@ -148,10 +152,10 @@ of the most common options to set are: spark.driver.memory 1g - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 1g, 2g). - -
      Note: In client mode, this config must not be set through the SparkConf + Amount of memory to use for the driver process, i.e. where SparkContext is initialized, in MiB + unless otherwise specified (e.g. 1g, 2g). +
      + Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-memory command line option or in your default properties file. @@ -161,27 +165,28 @@ of the most common options to set are: spark.driver.memoryOverhead driverMemory * 0.10, with minimum of 384 - The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is - memory that accounts for things like VM overheads, interned strings, other native overheads, etc. - This tends to grow with the container size (typically 6-10%). This option is currently supported - on YARN and Kubernetes. + The amount of off-heap memory to be allocated per driver in cluster mode, in MiB unless + otherwise specified. This is memory that accounts for things like VM overheads, interned strings, + other native overheads, etc. This tends to grow with the container size (typically 6-10%). + This option is currently supported on YARN and Kubernetes. spark.executor.memory 1g - Amount of memory to use per executor process (e.g. 2g, 8g). + Amount of memory to use per executor process, in MiB unless otherwise specified. + (e.g. 2g, 8g). spark.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 - The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that - accounts for things like VM overheads, interned strings, other native overheads, etc. This tends - to grow with the executor size (typically 6-10%). This option is currently supported on YARN and - Kubernetes. + The amount of off-heap memory to be allocated per executor, in MiB unless otherwise specified. + This is memory that accounts for things like VM overheads, interned strings, other native + overheads, etc. This tends to grow with the executor size (typically 6-10%). + This option is currently supported on YARN and Kubernetes. @@ -323,6 +328,11 @@ Apart from these, the following properties are also available, and may be useful Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory. + + The following symbols, if present will be interpolated: {{APP_ID}} will be replaced by + application ID and {{EXECUTOR_ID}} will be replaced by executor ID. For example, to enable + verbose gc logging to a file named for the executor ID of the app in /tmp, pass a 'value' of: + -verbose:gc -Xloggc:/tmp/{{APP_ID}}-{{EXECUTOR_ID}}.gc @@ -408,7 +418,7 @@ Apart from these, the following properties are also available, and may be useful false Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), - or it will be displayed before the driver exiting. It also can be dumped into disk by + or it will be displayed before the driver exits. It also can be dumped into disk by sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. @@ -431,8 +441,9 @@ Apart from these, the following properties are also available, and may be useful 512m Amount of memory to use per python worker process during aggregation, in the same - format as JVM memory strings (e.g. 512m, 2g). 
If the memory - used during aggregation goes above this amount, it will spill the data into disks. + format as JVM memory strings with a size unit suffix ("k", "m", "g" or "t") + (e.g. 512m, 2g). + If the memory used during aggregation goes above this amount, it will spill the data into disks. @@ -440,7 +451,7 @@ Apart from these, the following properties are also available, and may be useful true Reuse Python worker or not. If yes, it will use a fixed number of Python workers, - does not need to fork() a Python process for every tasks. It will be very useful + does not need to fork() a Python process for every task. It will be very useful if there is large broadcast, then the broadcast will not be needed to transferred from JVM to Python worker for every task. @@ -540,9 +551,10 @@ Apart from these, the following properties are also available, and may be useful spark.reducer.maxSizeInFlight 48m - Maximum size of map outputs to fetch simultaneously from each reduce task. Since - each output requires us to create a buffer to receive it, this represents a fixed memory - overhead per reduce task, so keep it small unless you have a large amount of memory. + Maximum size of map outputs to fetch simultaneously from each reduce task, in MiB unless + otherwise specified. Since each output requires us to create a buffer to receive it, this + represents a fixed memory overhead per reduce task, so keep it small unless you have a + large amount of memory. @@ -551,7 +563,7 @@ Apart from these, the following properties are also available, and may be useful This configuration limits the number of remote requests to fetch blocks at any given point. When the number of hosts in the cluster increase, it might lead to very large number - of in-bound connections to one or more nodes, causing the workers to fail under load. + of inbound connections to one or more nodes, causing the workers to fail under load. By allowing it to limit the number of fetch requests, this scenario can be mitigated. @@ -570,9 +582,9 @@ Apart from these, the following properties are also available, and may be useful spark.maxRemoteBlockSizeFetchToMem Long.MaxValue - The remote block will be fetched to disk when size of the block is above this threshold. + The remote block will be fetched to disk when size of the block is above this threshold in bytes. This is to avoid a giant request takes too much memory. We can enable this config by setting - a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch + a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch and block manager remote block fetch. For users who enabled external shuffle service, this feature can only be worked when external shuffle service is newer than Spark 2.2. @@ -589,8 +601,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.file.buffer 32k - Size of the in-memory buffer for each shuffle file output stream. These buffers - reduce the number of disk seeks and system calls made in creating intermediate shuffle files. + Size of the in-memory buffer for each shuffle file output stream, in KiB unless otherwise + specified. These buffers reduce the number of disk seeks and system calls made in creating + intermediate shuffle files. @@ -651,7 +664,7 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.service.index.cache.size 100m - Cache entries limited to the specified memory footprint. 
+ Cache entries limited to the specified memory footprint in bytes. @@ -685,9 +698,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.accurateBlockThreshold 100 * 1024 * 1024 - When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will record the - size accurately if it's above this config. This helps to prevent OOM by avoiding - underestimating shuffle block size when fetch shuffle blocks. + Threshold in bytes above which the size of shuffle blocks in HighlyCompressedMapStatus is + accurately recorded. This helps to prevent OOM by avoiding underestimating shuffle + block size when fetch shuffle blocks. @@ -704,30 +717,6 @@ Apart from these, the following properties are also available, and may be useful When we fail to register to the external shuffle service, we will retry for maxAttempts times. - - spark.io.encryption.enabled - false - - Enable IO encryption. Currently supported by all modes except Mesos. It's recommended that RPC encryption - be enabled when using this feature. - - - - spark.io.encryption.keySizeBits - 128 - - IO encryption key size in bits. Supported values are 128, 192 and 256. - - - - spark.io.encryption.keygen.algorithm - HmacSHA1 - - The algorithm to use when generating the IO encryption key. The supported algorithms are - described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm - Name Documentation. - - ### Spark UI @@ -779,7 +768,7 @@ Apart from these, the following properties are also available, and may be useful spark.eventLog.buffer.kb 100k - Buffer size in KB to use when writing to output streams. + Buffer size to use when writing to output streams, in KiB unless otherwise specified. @@ -885,6 +874,23 @@ Apart from these, the following properties are also available, and may be useful How many dead executors the Spark UI and status APIs remember before garbage collecting. + + spark.ui.filters + None + + Comma separated list of filter class names to apply to the Spark Web UI. The filter should be a + standard + javax servlet Filter. + +
      Filter parameters can also be specified in the configuration, by setting config entries + of the form spark.<class name of filter>.param.<param name>=<value> + +
      For example: +
      spark.ui.filters=com.test.filter1 +
      spark.com.test.filter1.param.name1=foo +
      spark.com.test.filter1.param.name2=bar + + ### Compression and Serialization @@ -917,7 +923,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.lz4.blockSize 32k - Block size used in LZ4 compression, in the case when LZ4 compression codec + Block size in bytes used in LZ4 compression, in the case when LZ4 compression codec is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used. @@ -925,7 +931,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.snappy.blockSize 32k - Block size used in Snappy compression, in the case when Snappy compression codec + Block size in bytes used in Snappy compression, in the case when Snappy compression codec is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. @@ -941,7 +947,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.zstd.bufferSize 32k - Buffer size used in Zstd compression, in the case when Zstd compression codec + Buffer size in bytes used in Zstd compression, in the case when Zstd compression codec is used. Lowering this size will lower the shuffle memory usage when Zstd is used, but it might increase the compression cost because of excessive JNI call overhead. @@ -1001,8 +1007,8 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer.max 64m - Maximum allowable size of Kryo serialization buffer. This must be larger than any - object you attempt to serialize and must be less than 2048m. + Maximum allowable size of Kryo serialization buffer, in MiB unless otherwise specified. + This must be larger than any object you attempt to serialize and must be less than 2048m. Increase this if you get a "buffer limit exceeded" exception inside Kryo. @@ -1010,9 +1016,9 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer 64k - Initial size of Kryo's serialization buffer. Note that there will be one buffer - per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max if needed. + Initial size of Kryo's serialization buffer, in KiB unless otherwise specified. + Note that there will be one buffer per core on each worker. This buffer will grow up to + spark.kryoserializer.buffer.max if needed. @@ -1086,7 +1092,8 @@ Apart from these, the following properties are also available, and may be useful spark.memory.offHeap.enabled false - If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. + If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory + use is enabled, then spark.memory.offHeap.size must be positive. @@ -1094,7 +1101,8 @@ Apart from these, the following properties are also available, and may be useful 0 The absolute amount of memory in bytes which can be used for off-heap allocation. - This setting has no impact on heap memory usage, so if your executors' total memory consumption must fit within some hard limit then be sure to shrink your JVM heap size accordingly. + This setting has no impact on heap memory usage, so if your executors' total memory consumption + must fit within some hard limit then be sure to shrink your JVM heap size accordingly. This must be set to a positive value when spark.memory.offHeap.enabled=true. 
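The practical upshot of the size-unit wording changes above: always give sizes an explicit suffix rather than relying on the per-property default of bytes, KiB, or MiB. A minimal, illustrative PySpark sketch (property values are arbitrary and chosen only for this example; driver memory is left to `spark-submit`, per the client-mode note above):

```python
from pyspark import SparkConf
from pyspark.sql import SparkSession

# Every size carries an explicit unit suffix, so nothing depends on the
# per-property default interpretation (bytes vs. KiB vs. MiB).
conf = (SparkConf()
        .set("spark.executor.memory", "4g")
        .set("spark.executor.memoryOverhead", "512m")
        .set("spark.reducer.maxSizeInFlight", "96m")
        .set("spark.shuffle.file.buffer", "64k")
        .set("spark.kryoserializer.buffer", "128k")
        .set("spark.kryoserializer.buffer.max", "128m"))

spark = SparkSession.builder.config(conf=conf).getOrCreate()
```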
@@ -1202,9 +1210,9 @@ Apart from these, the following properties are also available, and may be useful spark.broadcast.blockSize 4m - Size of each piece of a block for TorrentBroadcastFactory. - Too large a value decreases parallelism during broadcast (makes it slower); however, if it is - too small, BlockManager might take a performance hit. + Size of each piece of a block for TorrentBroadcastFactory, in KiB unless otherwise + specified. Too large a value decreases parallelism during broadcast (makes it slower); however, + if it is too small, BlockManager might take a performance hit. @@ -1284,8 +1292,8 @@ Apart from these, the following properties are also available, and may be useful spark.files.openCostInBytes 4194304 (4 MB) - The estimated cost to open a file, measured by the number of bytes could be scanned in the same - time. This is used when putting multiple files into a partition. It is better to over estimate, + The estimated cost to open a file, measured by the number of bytes could be scanned at the same + time. This is used when putting multiple files into a partition. It is better to overestimate, then the partitions with small files will be faster than partitions with bigger files. @@ -1312,7 +1320,7 @@ Apart from these, the following properties are also available, and may be useful spark.storage.memoryMapThreshold 2m - Size of a block above which Spark memory maps when reading a block from disk. + Size in bytes of a block above which Spark memory maps when reading a block from disk. This prevents Spark from memory mapping very small blocks. In general, memory mapping has high overhead for blocks close to or below the page size of the operating system. @@ -1436,6 +1444,15 @@ Apart from these, the following properties are also available, and may be useful Duration for an RPC remote endpoint lookup operation to wait before timing out. + + spark.core.connection.ack.wait.timeout + spark.network.timeout + + How long for the connection to wait for ack to occur before timing + out and giving up. To avoid unwilling timeout caused by long pause like GC, + you can set larger value. + + ### Scheduling @@ -1501,7 +1518,7 @@ Apart from these, the following properties are also available, and may be useful 0.8 for KUBERNETES mode; 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode The minimum ratio of registered resources (registered resources / total expected resources) - (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarsed-grained + (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarse-grained mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] ) to wait for before scheduling begins. Specified as a double between 0.0 and 1.0. Regardless of whether the minimum ratio of resources has been reached, @@ -1622,7 +1639,7 @@ Apart from these, the following properties are also available, and may be useful false (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch - failure happenes. If external shuffle service is enabled, then the whole node will be + failure happens. If external shuffle service is enabled, then the whole node will be blacklisted. 
@@ -1710,7 +1727,7 @@ Apart from these, the following properties are also available, and may be useful When spark.task.reaper.enabled = true, this setting specifies a timeout after which the executor JVM will kill itself if a killed task has not stopped running. The default value, -1, disables this mechanism and prevents the executor from self-destructing. The purpose - of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering + of this setting is to act as a safety-net to prevent runaway noncancellable tasks from rendering an executor unusable. @@ -1741,6 +1758,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.minExecutors, spark.dynamicAllocation.maxExecutors, and spark.dynamicAllocation.initialExecutors + spark.dynamicAllocation.executorAllocationRatio @@ -1785,6 +1803,23 @@ Apart from these, the following properties are also available, and may be useful Lower bound for the number of executors if dynamic allocation is enabled. + + spark.dynamicAllocation.executorAllocationRatio + 1 + + By default, the dynamic allocation will request enough executors to maximize the + parallelism according to the number of tasks to process. While this minimizes the + latency of the job, with small tasks this setting can waste a lot of resources due to + executor allocation overhead, as some executor might not even do any work. + This setting allows to set a ratio that will be used to reduce the number of + executors w.r.t. full parallelism. + Defaults to 1.0 to give maximum parallelism. + 0.5 will divide the target number of executors by 2 + The target number of executors computed by the dynamicAllocation can still be overriden + by the spark.dynamicAllocation.minExecutors and + spark.dynamicAllocation.maxExecutors settings + + spark.dynamicAllocation.schedulerBacklogTimeout 1s @@ -1807,313 +1842,8 @@ Apart from these, the following properties are also available, and may be useful ### Security - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
      Property NameDefaultMeaning
      spark.acls.enablefalse - Whether Spark acls should be enabled. If enabled, this checks to see if the user has - access permissions to view or modify the job. Note this requires the user to be known, - so if the user comes across as null no checks are done. Filters can be used with the UI - to authenticate and set the user. -
      spark.admin.aclsEmpty - Comma separated list of users/administrators that have view and modify access to all Spark jobs. - This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things do not work. Putting a "*" in the list means any user can have the - privilege of admin. -
      spark.admin.acls.groupsEmpty - Comma separated list of groups that have view and modify access to all Spark jobs. - This can be used if you have a set of administrators or developers who help maintain and debug - the underlying infrastructure. Putting a "*" in the list means any user in any group can have - the privilege of admin. The user groups are obtained from the instance of the groups mapping - provider specified by spark.user.groups.mapping. Check the entry - spark.user.groups.mapping for more details. -
      spark.user.groups.mappingorg.apache.spark.security.ShellBasedGroupsMappingProvider - The list of groups for a user are determined by a group mapping service defined by the trait - org.apache.spark.security.GroupMappingServiceProvider which can configured by this property. - A default unix shell based implementation is provided org.apache.spark.security.ShellBasedGroupsMappingProvider - which can be specified to resolve a list of groups for a user. - Note: This implementation supports only a Unix/Linux based environment. Windows environment is - currently not supported. However, a new platform/protocol can be supported by implementing - the trait org.apache.spark.security.GroupMappingServiceProvider. -
      spark.authenticatefalse - Whether Spark authenticates its internal connections. See - spark.authenticate.secret if not running on YARN. -
      spark.authenticate.secretNone - Set the secret key used for Spark to authenticate between components. This needs to be set if - not running on YARN and authentication is enabled. -
      spark.network.crypto.enabledfalse - Enable encryption using the commons-crypto library for RPC and block transfer service. - Requires spark.authenticate to be enabled. -
      spark.network.crypto.keyLength128 - The length in bits of the encryption key to generate. Valid values are 128, 192 and 256. -
      spark.network.crypto.keyFactoryAlgorithmPBKDF2WithHmacSHA1 - The key factory algorithm to use when generating encryption keys. Should be one of the - algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used. -
      spark.network.crypto.saslFallbacktrue - Whether to fall back to SASL authentication if authentication fails using Spark's internal - mechanism. This is useful when the application is connecting to old shuffle services that - do not support the internal Spark authentication protocol. On the server side, this can be - used to block older clients from authenticating against a new shuffle service. -
      spark.network.crypto.config.*None - Configuration values for the commons-crypto library, such as which cipher implementations to - use. The config name should be the name of commons-crypto configuration without the - "commons.crypto" prefix. -
      spark.authenticate.enableSaslEncryptionfalse - Enable encrypted communication when authentication is - enabled. This is supported by the block transfer service and the - RPC endpoints. -
      spark.network.sasl.serverAlwaysEncryptfalse - Disable unencrypted connections for services that support SASL authentication. -
      spark.core.connection.ack.wait.timeoutspark.network.timeout - How long for the connection to wait for ack to occur before timing - out and giving up. To avoid unwilling timeout caused by long pause like GC, - you can set larger value. -
      spark.modify.aclsEmpty - Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). Putting a "*" in - the list means any user can have access to modify it. -
      spark.modify.acls.groupsEmpty - Comma separated list of groups that have modify access to the Spark job. This can be used if you - have a set of administrators or developers from the same team to have access to control the job. - Putting a "*" in the list means any user in any group has the access to modify the Spark job. - The user groups are obtained from the instance of the groups mapping provider specified by - spark.user.groups.mapping. Check the entry spark.user.groups.mapping - for more details. -
      spark.ui.filtersNone - Comma separated list of filter class names to apply to the Spark web UI. The filter should be a - standard - javax servlet Filter. Parameters to each filter can also be specified by setting a - java system property of:
      - spark.<class name of filter>.params='param1=value1,param2=value2'
      - For example:
      - -Dspark.ui.filters=com.test.filter1
      - -Dspark.com.test.filter1.params='param1=foo,param2=testing' -
      spark.ui.view.aclsEmpty - Comma separated list of users that have view access to the Spark web ui. By default only the - user that started the Spark job has view access. Putting a "*" in the list means any user can - have view access to this Spark job. -
      spark.ui.view.acls.groupsEmpty - Comma separated list of groups that have view access to the Spark web ui to view the Spark Job - details. This can be used if you have a set of administrators or developers or users who can - monitor the Spark job submitted. Putting a "*" in the list means any user in any group can view - the Spark job details on the Spark web ui. The user groups are obtained from the instance of the - groups mapping provider specified by spark.user.groups.mapping. Check the entry - spark.user.groups.mapping for more details. -
      - -### TLS / SSL - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
      Property NameDefaultMeaning
      spark.ssl.enabledfalse - Whether to enable SSL connections on all supported protocols. - -
      When spark.ssl.enabled is configured, spark.ssl.protocol - is required. - -
      All the SSL settings like spark.ssl.xxx where xxx is a - particular configuration property, denote the global configuration for all the supported - protocols. In order to override the global configuration for the particular protocol, - the properties must be overwritten in the protocol-specific namespace. - -
      Use spark.ssl.YYY.XXX settings to overwrite the global configuration for - particular protocol denoted by YYY. Example values for YYY - include fs, ui, standalone, and - historyServer. See SSL - Configuration for details on hierarchical SSL configuration for services. -
      spark.ssl.[namespace].portNone - The port where the SSL service will listen on. - -
      The port must be defined within a namespace configuration; see - SSL Configuration for the available - namespaces. - -
      When not set, the SSL port will be derived from the non-SSL port for the - same service. A value of "0" will make the service bind to an ephemeral port. -
      spark.ssl.enabledAlgorithmsEmpty - A comma separated list of ciphers. The specified ciphers must be supported by JVM. - The reference list of protocols one can find on - this - page. - Note: If not set, it will use the default cipher suites of JVM. -
      spark.ssl.keyPasswordNone - A password to the private key in key-store. -
      spark.ssl.keyStoreNone - A path to a key-store file. The path can be absolute or relative to the directory where - the component is started in. -
      spark.ssl.keyStorePasswordNone - A password to the key-store. -
      spark.ssl.keyStoreTypeJKS - The type of the key-store. -
      spark.ssl.protocolNone - A protocol name. The protocol must be supported by JVM. The reference list of protocols - one can find on this - page. -
      spark.ssl.needClientAuthfalse - Set true if SSL needs client authentication. -
      spark.ssl.trustStoreNone - A path to a trust-store file. The path can be absolute or relative to the directory - where the component is started in. -
      spark.ssl.trustStorePasswordNone - A password to the trust-store. -
      spark.ssl.trustStoreTypeJKS - The type of the trust-store. -
      - +Please refer to the [Security](security.html) page for available options on how to secure different +Spark subsystems. ### Spark SQL @@ -2208,8 +1938,8 @@ showDF(properties, numRows = 200, truncate = FALSE) spark.streaming.receiver.writeAheadLog.enable false - Enable write ahead logs for receivers. All the input data received through receivers - will be saved to write ahead logs that will allow it to be recovered after driver failures. + Enable write-ahead logs for receivers. All the input data received through receivers + will be saved to write-ahead logs that will allow it to be recovered after driver failures. See the deployment guide in the Spark Streaming programing guide for more details. @@ -2264,7 +1994,7 @@ showDF(properties, numRows = 200, truncate = FALSE) spark.streaming.driver.writeAheadLog.closeFileAfterWrite false - Whether to close the file after writing a write ahead log record on the driver. Set this to 'true' + Whether to close the file after writing a write-ahead log record on the driver. Set this to 'true' when you want to use S3 (or any file system that does not support flushing) for the metadata WAL on the driver. @@ -2273,7 +2003,7 @@ showDF(properties, numRows = 200, truncate = FALSE) spark.streaming.receiver.writeAheadLog.closeFileAfterWrite false - Whether to close the file after writing a write ahead log record on the receivers. Set this to 'true' + Whether to close the file after writing a write-ahead log record on the receivers. Set this to 'true' when you want to use S3 (or any file system that does not support flushing) for the data WAL on the receivers. @@ -2455,7 +2185,7 @@ should be included on Spark's classpath: The location of these configuration files varies across Hadoop versions, but a common location is inside of `/etc/hadoop/conf`. Some tools create -configurations on-the-fly, but offer a mechanisms to download copies of them. +configurations on-the-fly, but offer a mechanism to download copies of them. To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/conf/spark-env.sh` to a location containing the configuration files. @@ -2471,7 +2201,7 @@ Spark's classpath for each application. In a Spark cluster running on YARN, thes files are set cluster-wide, and cannot safely be changed by the application. The better choice is to use spark hadoop properties in the form of `spark.hadoop.*`. -They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-defalut.conf` +They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-default.conf` In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, Spark allows you to simply create an empty conf and set spark/spark hadoop properties. @@ -2490,4 +2220,4 @@ Also, you can modify or add configurations at runtime: --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" \ --conf spark.hadoop.abc.def=xyz \ myApp.jar -{% endhighlight %} \ No newline at end of file +{% endhighlight %} diff --git a/docs/css/pygments-default.css b/docs/css/pygments-default.css index 6247cd8396cf..a4d583b36660 100644 --- a/docs/css/pygments-default.css +++ b/docs/css/pygments-default.css @@ -5,7 +5,7 @@ To generate this, I had to run But first I had to install pygments via easy_install pygments I had to override the conflicting bootstrap style rules by linking to -this stylesheet lower in the html than the bootstap css. 
+this stylesheet lower in the html than the bootstrap css. Also, I was thrown off for a while at first when I was using markdown code block inside my {% highlight scala %} ... {% endhighlight %} tags diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 46225dc598da..35293348e3f3 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -491,7 +491,7 @@ val joinedGraph = graph.joinVertices(uniqueCosts)( The more general [`outerJoinVertices`][Graph.outerJoinVertices] behaves similarly to `joinVertices` except that the user defined `map` function is applied to all vertices and can change the vertex property type. Because not all vertices may have a matching value in the input RDD the `map` -function takes an `Option` type. For example, we can setup a graph for PageRank by initializing +function takes an `Option` type. For example, we can set up a graph for PageRank by initializing vertex properties with their `outDegree`. @@ -708,7 +708,7 @@ messages remaining. > messaging function. These constraints allow additional optimization within GraphX. The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* -of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodcally +of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodically checkpoint graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number, say 10. And set checkpoint directory as well using SparkContext.setCheckpointDir(directory: String)): @@ -928,7 +928,7 @@ switch to 2D-partitioning or other heuristics included in GraphX.

      -Once the edges have be partitioned the key challenge to efficient graph-parallel computation is +Once the edges have been partitioned the key challenge to efficient graph-parallel computation is efficiently joining vertex attributes with the edges. Because real-world graphs typically have more edges than vertices, we move vertex attributes to the edges. Because not all partitions will contain edges adjacent to all vertices we internally maintain a routing table which identifies where @@ -969,7 +969,7 @@ A vertex is part of a triangle when it has two adjacent vertices with an edge be # Examples Suppose I want to build a graph from some text files, restrict the graph -to important relationships and users, run page-rank on the sub-graph, and +to important relationships and users, run page-rank on the subgraph, and then finally return attributes associated with the top users. I can do all of this in just a few lines with GraphX: diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index e6d881639a13..da90342406c8 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -23,7 +23,7 @@ run tasks and store data for that application. If multiple users need to share y different options to manage allocation, depending on the cluster manager. The simplest option, available on all cluster managers, is _static partitioning_ of resources. With -this approach, each application is given a maximum amount of resources it can use, and holds onto them +this approach, each application is given a maximum amount of resources it can use and holds onto them for its whole duration. This is the approach used in Spark's [standalone](spark-standalone.html) and [YARN](running-on-yarn.html) modes, as well as the [coarse-grained Mesos mode](running-on-mesos.html#mesos-run-modes). @@ -230,7 +230,7 @@ properties: * `minShare`: Apart from an overall weight, each pool can be given a _minimum shares_ (as a number of CPU cores) that the administrator would like it to have. The fair scheduler always attempts to meet all active pools' minimum shares before redistributing extra resources according to the weights. - The `minShare` property can therefore be another way to ensure that a pool can always get up to a + The `minShare` property can, therefore, be another way to ensure that a pool can always get up to a certain number of resources (e.g. 10 cores) quickly without giving it a high priority for the rest of the cluster. By default, each pool's `minShare` is 0. diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md index 2747f2df7cb1..375957e92cc4 100644 --- a/docs/ml-advanced.md +++ b/docs/ml-advanced.md @@ -77,7 +77,7 @@ Quasi-Newton methods in this case. This fallback is currently always enabled for L1 regularization is applied (i.e. $\alpha = 0$), there exists an analytical solution and either Cholesky or Quasi-Newton solver may be used. When $\alpha > 0$ no analytical solution exists and we instead use the Quasi-Newton solver to find the coefficients iteratively. -In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features be no more than 4096. For larger problems, use L-BFGS instead. +In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features is no more than 4096. For larger problems, use L-BFGS instead. 
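Schematically (and glossing over the standardization and weight normalization the implementation applies), the discussion above concerns the elastic-net-regularized weighted least squares objective

`\begin{equation}
  \min_{\beta} \frac{1}{2} \sum_{i=1}^n w_i (x_i^T \beta - y_i)^2
  + \lambda \left( \alpha \|\beta\|_1 + \frac{1-\alpha}{2} \|\beta\|_2^2 \right)
\end{equation}`

For $\alpha = 0$ the minimizer satisfies the normal equations $(X^T W X + \lambda I) \beta = X^T W y$, which is what the Cholesky path factorizes; for $\alpha > 0$ the L1 term removes the closed form, hence the iterative Quasi-Newton solver.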
## Iteratively reweighted least squares (IRLS) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index bf979f3c73a5..d660655e193e 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -87,7 +87,7 @@ More details on parameters can be found in the [R API documentation](api/R/spark The `spark.ml` implementation of logistic regression also supports extracting a summary of the model over the training set. Note that the predictions and metrics which are stored as `DataFrame` in -`BinaryLogisticRegressionSummary` are annotated `@transient` and hence +`LogisticRegressionSummary` are annotated `@transient` and hence only available on the driver.
      @@ -97,10 +97,9 @@ only available on the driver. [`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) provides a summary for a [`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). -This will likely change when multiclass classification is supported. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. The binary summary can be accessed via the +`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). Continuing the earlier example: @@ -111,10 +110,9 @@ Continuing the earlier example: [`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) provides a summary for a [`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). -Support for multiclass model summaries will be added in the future. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. The binary summary can be accessed via the +`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). Continuing the earlier example: @@ -125,7 +123,8 @@ Continuing the earlier example: [`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary) provides a summary for a [`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. See [`BinaryLogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary). Continuing the earlier example: @@ -162,7 +161,8 @@ For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multin **Examples** The following example shows how to train a multiclass logistic regression -model with elastic net regularization. +model with elastic net regularization, as well as extract the multiclass +training summary for evaluating the model.
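For readers following along in Python, a self-contained sketch of the summary workflow described above (toy data invented for this example; in PySpark the binary metrics are read directly off `model.summary` rather than through a separate `binarySummary` call):

```python
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy binary-labeled data, invented purely for this sketch.
training = spark.createDataFrame([
    (1.0, Vectors.dense(0.0, 1.1, 0.1)),
    (0.0, Vectors.dense(2.0, 1.0, -1.0)),
    (0.0, Vectors.dense(2.0, 1.3, 1.0)),
    (1.0, Vectors.dense(0.0, 1.2, -0.5))], ["label", "features"])

lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
model = lr.fit(training)

# The training summary is only available on the driver; with a binary label
# it also carries ROC-based metrics.
summary = model.summary
print(summary.objectiveHistory)   # loss at each iteration
summary.roc.show()                # DataFrame of (FPR, TPR) points
print(summary.areaUnderROC)
```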
      @@ -420,7 +420,7 @@ Refer to the [R API docs](api/R/spark.svmLinear.html) for more details. [OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." -`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. +`OneVsRest` is implemented as an `Estimator`. For the base classifier, it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. @@ -908,7 +908,7 @@ Refer to the [R API docs](api/R/spark.survreg.html) for more details. belongs to the family of regression algorithms. Formally isotonic regression is a problem where given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted -finding a function that minimises +finding a function that minimizes `\begin{equation} f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 @@ -927,7 +927,7 @@ We implement a which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). The training input is a DataFrame which contains three columns -label, features and weight. Additionally IsotonicRegression algorithm has one +label, features and weight. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is isotonic (monotonically increasing) or antitonic (monotonically decreasing). diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 58f2d4b531e7..8b0f287dc39a 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -35,7 +35,7 @@ but the ids must be within the integer value range. ### Explicit vs. implicit feedback -The standard approach to matrix factorization based collaborative filtering treats +The standard approach to matrix factorization-based collaborative filtering treats the entries in the user-item matrix as *explicit* preferences given by the user to the item, for example, users giving ratings to movies. diff --git a/docs/ml-features.md b/docs/ml-features.md index 72643137d96b..7aed2341584f 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -222,9 +222,9 @@ The `FeatureHasher` transformer operates on multiple columns. Each column may co numeric or categorical features. Behavior and handling of column data types is as follows: - Numeric columns: For numeric features, the hash value of the column name is used to map the -feature value to its index in the feature vector. Numeric features are never treated as -categorical, even when they are integers. You must explicitly convert numeric columns containing -categorical features to strings first. +feature value to its index in the feature vector. 
By default, numeric features are not treated +as categorical (even when they are integers). To treat them as categorical, specify the relevant +columns using the `categoricalCols` parameter. - String columns: For categorical features, the hash value of the string "column_name=value" is used to map to the vector index, with an indicator value of `1.0`. Thus, categorical features are "one-hot" encoded (similarly to using [OneHotEncoder](ml-features.html#onehotencoder) with @@ -775,35 +775,43 @@ for more details on the API.
      -## OneHotEncoder +## OneHotEncoder (Deprecated since 2.3.0) -[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. +Because this existing `OneHotEncoder` is a stateless transformer, it is not usable on new data where the number of categories may differ from the training data. In order to fix this, a new `OneHotEncoderEstimator` was created that produces an `OneHotEncoderModel` when fitting. For more detail, please see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030). + +`OneHotEncoder` has been deprecated in 2.3.0 and will be removed in 3.0.0. Please use [OneHotEncoderEstimator](ml-features.html#onehotencoderestimator) instead. + +## OneHotEncoderEstimator + +[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a categorical feature, represented as a label index, to a binary vector with at most a single one-value indicating the presence of a specific feature value from among the set of all feature values. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first. + +`OneHotEncoderEstimator` can transform multiple columns, returning an one-hot-encoded output vector column for each input column. It is common to merge these vectors into a single feature vector using [VectorAssembler](ml-features.html#vectorassembler). + +`OneHotEncoderEstimator` supports the `handleInvalid` parameter to choose how to handle invalid input during transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error). **Examples**
      -Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) -for more details on the API. +Refer to the [OneHotEncoderEstimator Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoderEstimator) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala %}
      -Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) +Refer to the [OneHotEncoderEstimator Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoderEstimator.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java %}
      -Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) -for more details on the API. +Refer to the [OneHotEncoderEstimator Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoderEstimator) for more details on the API. -{% include_example python/ml/onehot_encoder_example.py %} +{% include_example python/ml/onehot_encoder_estimator_example.py %}
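Beyond the bundled examples referenced above, a minimal PySpark sketch of the fit-then-transform workflow the new section describes (column names and data are made up for illustration):

```python
from pyspark.ml.feature import OneHotEncoderEstimator
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Label indices as they might come out of StringIndexer; values are made up.
df = spark.createDataFrame(
    [(0.0, 1.0), (1.0, 0.0), (2.0, 1.0), (0.0, 2.0), (0.0, 1.0), (2.0, 0.0)],
    ["categoryIndex1", "categoryIndex2"])

# Fitting records the number of categories per column, so the resulting
# OneHotEncoderModel can be applied to new data, unlike the old stateless
# OneHotEncoder transformer.
encoder = OneHotEncoderEstimator(inputCols=["categoryIndex1", "categoryIndex2"],
                                 outputCols=["categoryVec1", "categoryVec2"])
model = encoder.fit(df)
model.transform(df).show()
```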
      @@ -1166,7 +1174,7 @@ for more details on the API. ## SQLTransformer `SQLTransformer` implements the transformations which are defined by SQL statement. -Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` +Currently, we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` where `"__THIS__"` represents the underlying table of the input dataset. The select clause specifies the fields, constants, and expressions to display in the output, and can be any select clause that Spark SQL supports. Users can also @@ -1283,6 +1291,57 @@ for more details on the API.
+## VectorSizeHint + +It can sometimes be useful to explicitly specify the size of the vectors for a column of +`VectorType`. For example, `VectorAssembler` uses size information from its input columns to +produce size information and metadata for its output column. While in some cases this information +can be obtained by inspecting the contents of the column, in a streaming dataframe the contents are +not available until the stream is started. `VectorSizeHint` allows a user to explicitly specify the +vector size for a column so that `VectorAssembler`, or other transformers that might +need to know vector size, can use that column as an input. + +To use `VectorSizeHint` a user must set the `inputCol` and `size` parameters. Applying this +transformer to a dataframe produces a new dataframe with updated metadata for `inputCol` specifying +the vector size. Downstream operations on the resulting dataframe can get this size using the +metadata. + +`VectorSizeHint` can also take an optional `handleInvalid` parameter which controls its +behaviour when the vector column contains nulls or vectors of the wrong size. By default +`handleInvalid` is set to "error", indicating an exception should be thrown. This parameter can +also be set to "skip", indicating that rows containing invalid values should be filtered out from +the resulting dataframe, or "optimistic", indicating that the column should not be checked for +invalid values and all rows should be kept. Note that the use of "optimistic" can cause the +resulting dataframe to be in an inconsistent state, meaning the metadata for the column +`VectorSizeHint` was applied to does not match the contents of that column. Users should take care +to avoid this kind of inconsistent state. + +
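A compact PySpark sketch of the workflow just described (column names and values are illustrative only; the bundled `vector_size_hint_example` sources referenced below are the canonical versions):

```python
from pyspark.ml.feature import VectorAssembler, VectorSizeHint
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Made-up rows; "userFeatures" carries no size metadata of its own.
df = spark.createDataFrame(
    [(0, 18.0, Vectors.dense(0.0, 10.0, 0.5)),
     (1, 22.0, Vectors.dense(1.0, 11.0, 0.7))],
    ["id", "hour", "userFeatures"])

# Declare the vector size up front so the downstream VectorAssembler can also
# be used on a streaming DataFrame, where the column contents cannot be inspected.
size_hint = VectorSizeHint(inputCol="userFeatures", size=3, handleInvalid="skip")
assembler = VectorAssembler(inputCols=["hour", "userFeatures"], outputCol="features")

assembler.transform(size_hint.transform(df)).select("features").show(truncate=False)
```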
    +
    + +Refer to the [VectorSizeHint Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSizeHint) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala %} +
    + +
    + +Refer to the [VectorSizeHint Java docs](api/java/org/apache/spark/ml/feature/VectorSizeHint.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java %} +
    + +
    + +Refer to the [VectorSizeHint Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorSizeHint) +for more details on the API. + +{% include_example python/ml/vector_size_hint_example.py %} +
    +
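As a compact illustration of the parameters described above (a sketch only; the column names, vector size and sample row are invented for this snippet), `VectorSizeHint` is typically placed in front of `VectorAssembler`:

{% highlight python %}
from pyspark.ml.feature import VectorAssembler, VectorSizeHint
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("VectorSizeHintSketch").getOrCreate()

df = spark.createDataFrame(
    [(18, Vectors.dense([0.0, 10.0, 0.5]), 1.0)],
    ["hour", "userFeatures", "clicked"])

# Declare that "userFeatures" always has three elements; with handleInvalid
# set to "skip", rows holding a null or wrongly sized vector are dropped.
size_hint = VectorSizeHint(inputCol="userFeatures", size=3, handleInvalid="skip")
with_size = size_hint.transform(df)

# Because the vector size now comes from column metadata rather than from the
# data itself, VectorAssembler can also consume this column in a streaming query.
assembler = VectorAssembler(inputCols=["hour", "userFeatures"], outputCol="features")
assembler.transform(with_size).select("features", "clicked").show(truncate=False)
{% endhighlight %}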
    + ## QuantileDiscretizer `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned diff --git a/docs/ml-guide.md b/docs/ml-guide.md index f6288e7c32d9..aea07be34cb8 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -72,32 +72,31 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 [^1]: To learn more about the benefits and background of system optimised natives, you may wish to watch Sam Halliday's ScalaX talk on [High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/). -# Highlights in 2.2 +# Highlights in 2.3 -The list below highlights some of the new features and enhancements added to MLlib in the `2.2` +The list below highlights some of the new features and enhancements added to MLlib in the `2.3` release of Spark: -* [`ALS`](ml-collaborative-filtering.html) methods for _top-k_ recommendations for all - users or items, matching the functionality in `mllib` - ([SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)). - Performance was also improved for both `ml` and `mllib` - ([SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968) and - [SPARK-20587](https://issues.apache.org/jira/browse/SPARK-20587)) -* [`Correlation`](ml-statistics.html#correlation) and - [`ChiSquareTest`](ml-statistics.html#hypothesis-testing) stats functions for `DataFrames` - ([SPARK-19636](https://issues.apache.org/jira/browse/SPARK-19636) and - [SPARK-19635](https://issues.apache.org/jira/browse/SPARK-19635)) -* [`FPGrowth`](ml-frequent-pattern-mining.html#fp-growth) algorithm for frequent pattern mining - ([SPARK-14503](https://issues.apache.org/jira/browse/SPARK-14503)) -* `GLM` now supports the full `Tweedie` family - ([SPARK-18929](https://issues.apache.org/jira/browse/SPARK-18929)) -* [`Imputer`](ml-features.html#imputer) feature transformer to impute missing values in a dataset - ([SPARK-13568](https://issues.apache.org/jira/browse/SPARK-13568)) -* [`LinearSVC`](ml-classification-regression.html#linear-support-vector-machine) - for linear Support Vector Machine classification - ([SPARK-14709](https://issues.apache.org/jira/browse/SPARK-14709)) -* Logistic regression now supports constraints on the coefficients during training - ([SPARK-20047](https://issues.apache.org/jira/browse/SPARK-20047)) +* Built-in support for reading images into a `DataFrame` was added +([SPARK-21866](https://issues.apache.org/jira/browse/SPARK-21866)). +* [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) was added, and should be +used instead of the existing `OneHotEncoder` transformer. The new estimator supports +transforming multiple columns. +* Multiple column support was also added to `QuantileDiscretizer` and `Bucketizer` +([SPARK-22397](https://issues.apache.org/jira/browse/SPARK-22397) and +[SPARK-20542](https://issues.apache.org/jira/browse/SPARK-20542)) +* A new [`FeatureHasher`](ml-features.html#featurehasher) transformer was added + ([SPARK-13969](https://issues.apache.org/jira/browse/SPARK-13969)). +* Added support for evaluating multiple models in parallel when performing cross-validation using +[`TrainValidationSplit` or `CrossValidator`](ml-tuning.html) +([SPARK-19357](https://issues.apache.org/jira/browse/SPARK-19357)). +* Improved support for custom pipeline components in Python (see +[SPARK-21633](https://issues.apache.org/jira/browse/SPARK-21633) and +[SPARK-21542](https://issues.apache.org/jira/browse/SPARK-21542)). 
+* `DataFrame` functions for descriptive summary statistics over vector columns +([SPARK-19634](https://issues.apache.org/jira/browse/SPARK-19634)). +* Robust linear regression with Huber loss +([SPARK-3181](https://issues.apache.org/jira/browse/SPARK-3181)). # Migration guide @@ -109,42 +108,40 @@ and the migration guide below will explain all changes between releases. ### Breaking changes -There are no breaking changes. +* The class and trait hierarchy for logistic regression model summaries was changed to be cleaner +and better accommodate the addition of the multi-class summary. This is a breaking change for user +code that casts a `LogisticRegressionTrainingSummary` to a +`BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` +method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail +(_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which +will still work correctly for both multinomial and binary cases. ### Deprecations and changes of behavior **Deprecations** -There are no deprecations. +* `OneHotEncoder` has been deprecated and will be removed in `3.0`. It has been replaced by the +new [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) +(see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030)). **Note** that +`OneHotEncoderEstimator` will be renamed to `OneHotEncoder` in `3.0` (but +`OneHotEncoderEstimator` will be kept as an alias). **Changes of behavior** * [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027): - We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial), in 2.2 and earlier version, - the `OneVsRest` parallelism would be parallelism of the default threadpool in scala. - -## From 2.1 to 2.2 - -### Breaking changes - -There are no breaking changes. - -### Deprecations and changes of behavior - -**Deprecations** - -There are no deprecations. - -**Changes of behavior** - -* [SPARK-19787](https://issues.apache.org/jira/browse/SPARK-19787): - Default value of `regParam` changed from `1.0` to `0.1` for `ALS.train` method (marked `DeveloperApi`). - **Note** this does _not affect_ the `ALS` Estimator or Model, nor MLlib's `ALS` class. -* [SPARK-14772](https://issues.apache.org/jira/browse/SPARK-14772): - Fixed inconsistency between Python and Scala APIs for `Param.copy` method. -* [SPARK-11569](https://issues.apache.org/jira/browse/SPARK-11569): - `StringIndexer` now handles `NULL` values in the same way as unseen values. Previously an exception - would always be thrown regardless of the setting of the `handleInvalid` parameter. + The default parallelism used in `OneVsRest` is now set to 1 (i.e. serial). In `2.2` and + earlier versions, the level of parallelism was set to the default threadpool size in Scala. +* [SPARK-22156](https://issues.apache.org/jira/browse/SPARK-22156): + The learning rate update for `Word2Vec` was incorrect when `numIterations` was set greater than + `1`. This will cause training results to be different between `2.3` and earlier versions. +* [SPARK-21681](https://issues.apache.org/jira/browse/SPARK-21681): + Fixed an edge case bug in multinomial logistic regression that resulted in incorrect coefficients + when some features had zero variance. +* [SPARK-16957](https://issues.apache.org/jira/browse/SPARK-16957): + Tree algorithms now use mid-points for split values. This may change results from model training. 
+* [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657): + Fixed an issue where the features generated by `RFormula` without an intercept were inconsistent + with the output in R. This may change results from model training in this scenario. ## Previous Spark versions diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md index 687d7c893036..e4736411fb5f 100644 --- a/docs/ml-migration-guides.md +++ b/docs/ml-migration-guides.md @@ -7,6 +7,29 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Guide main page](ml-guide.html#migration-guide). +## From 2.1 to 2.2 + +### Breaking changes + +There are no breaking changes. + +### Deprecations and changes of behavior + +**Deprecations** + +There are no deprecations. + +**Changes of behavior** + +* [SPARK-19787](https://issues.apache.org/jira/browse/SPARK-19787): + Default value of `regParam` changed from `1.0` to `0.1` for `ALS.train` method (marked `DeveloperApi`). + **Note** this does _not affect_ the `ALS` Estimator or Model, nor MLlib's `ALS` class. +* [SPARK-14772](https://issues.apache.org/jira/browse/SPARK-14772): + Fixed inconsistency between Python and Scala APIs for `Param.copy` method. +* [SPARK-11569](https://issues.apache.org/jira/browse/SPARK-11569): + `StringIndexer` now handles `NULL` values in the same way as unseen values. Previously an exception + would always be thrown regardless of the setting of the `handleInvalid` parameter. + ## From 2.0 to 2.1 ### Breaking changes @@ -324,7 +347,7 @@ rather than using the old parameter class `Strategy`. These new training method separate classification and regression, and they replace specialized parameter types with simple `String` types. -Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the +Examples of the new recommended `trainClassifier` and `trainRegressor` are given in the [Decision Trees Guide](mllib-decision-tree.html#examples). ## From 0.9 to 1.0 diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index aa92c0a37c0f..e22e9003c30f 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -188,9 +188,36 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -## Saving and Loading Pipelines +## ML persistence: Saving and Loading Pipelines -Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. +Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. +As of Spark 2.3, the DataFrame-based API in `spark.ml` and `pyspark.ml` has complete coverage. + +ML persistence works across Scala, Java and Python. However, R currently uses a modified format, +so models saved in R can only be loaded back in R; this should be fixed in the future and is +tracked in [SPARK-15572](https://issues.apache.org/jira/browse/SPARK-15572). 
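As a minimal sketch of the persistence API (the pipeline stages are only an example and the save path is a placeholder), a fitted `PipelineModel` can be written out and read back like this:

{% highlight python %}
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("PersistenceSketch").getOrCreate()

training = spark.createDataFrame([
    (0, "a b c d e spark", 1.0),
    (1, "b d", 0.0),
    (2, "spark f g h", 1.0),
    (3, "hadoop mapreduce", 0.0)
], ["id", "text", "label"])

tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol="words", outputCol="features")
lr = LogisticRegression(maxIter=10)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
model = pipeline.fit(training)

# Unfitted Pipelines and fitted PipelineModels share the same writer/reader API.
model.write().overwrite().save("/tmp/spark-lr-model")   # placeholder path
same_model = PipelineModel.load("/tmp/spark-lr-model")
{% endhighlight %}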
+ +### Backwards compatibility for ML persistence + +In general, MLlib maintains backwards compatibility for ML persistence. I.e., if you save an ML +model or Pipeline in one version of Spark, then you should be able to load it back and use it in a +future version of Spark. However, there are rare exceptions, described below. + +Model persistence: Is a model or Pipeline saved using Apache Spark ML persistence in Spark +version X loadable by Spark version Y? + +* Major versions: No guarantees, but best-effort. +* Minor and patch versions: Yes; these are backwards compatible. +* Note about the format: There are no guarantees for a stable persistence format, but model loading itself is designed to be backwards compatible. + +Model behavior: Does a model or Pipeline in Spark version X behave identically in Spark version Y? + +* Major versions: No guarantees, but best-effort. +* Minor and patch versions: Identical behavior, except for bug fixes. + +For both model persistence and model behavior, any breaking changes across a minor version or patch +version are reported in the Spark version release notes. If a breakage is not reported in release +notes, then it should be treated as a bug to be fixed. # Code examples diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index 54d9cd21909d..028bfec465ba 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -103,7 +103,7 @@ Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.m In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. `TrainValidationSplit` only evaluates each combination of parameters once, as opposed to k times in - the case of `CrossValidator`. It is therefore less expensive, + the case of `CrossValidator`. It is, therefore, less expensive, but will not produce as reliable results when the training dataset is not sufficiently large. Unlike `CrossValidator`, `TrainValidationSplit` creates a single (training, test) dataset pair. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index df2be92d860e..dc6b095f5d59 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -42,7 +42,7 @@ The following code snippets can be executed in `spark-shell`. In the following example after loading and parsing data, we use the [`KMeans`](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then compute Within -Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact the +Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact, the optimal *k* is usually one where there is an "elbow" in the WSSSE graph. Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`KMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeansModel) for details on the API. diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 76a00f18b3b9..b2300028e151 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -31,7 +31,7 @@ following parameters: ### Explicit vs. 
implicit feedback -The standard approach to matrix factorization based collaborative filtering treats +The standard approach to matrix factorization-based collaborative filtering treats the entries in the user-item matrix as *explicit* preferences given by the user to the item, for example, users giving ratings to movies. @@ -60,7 +60,7 @@ best parameter learned from a sampled subset to the full dataset and expect simi
    -In the following example we load rating data. Each row consists of a user, a product and a rating. +In the following example, we load rating data. Each row consists of a user, a product and a rating. We use the default [ALS.train()](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS$) method which assumes ratings are explicit. We evaluate the recommendation model by measuring the Mean Squared Error of rating prediction. diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 35cee3275e3b..5066bb29387d 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -350,7 +350,7 @@ which is a tuple of `(Int, Int, Matrix)`. ***Note*** The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. -In general the use of non-deterministic RDDs can lead to errors. +In general, the use of non-deterministic RDDs can lead to errors. ### RowMatrix diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index a72680d52a26..4e6b4530942f 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -91,7 +91,7 @@ The same code applies to `IndexedRowMatrix` if `U` is defined as an [Principal component analysis (PCA)](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical method to find a rotation such that the first coordinate has the largest variance -possible, and each succeeding coordinate in turn has the largest variance possible. The columns of +possible, and each succeeding coordinate, in turn, has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. `spark.mllib` supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors. diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 7f277543d2e9..d9dbbab4840a 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -13,7 +13,7 @@ of the model on some criteria, which depends on the application and its requirem suite of metrics for the purpose of evaluating the performance of machine learning models. Specific machine learning algorithms fall under broader types of machine learning applications like classification, -regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those +regression, clustering, etc. Each of these types have well-established metrics for performance evaluation and those metrics that are currently available in `spark.mllib` are detailed in this section. ## Classification model evaluation diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 75aea7060187..bb29f65c0322 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -105,7 +105,7 @@ p(w_i | w_j ) = \frac{\exp(u_{w_i}^{\top}v_{w_j})}{\sum_{l=1}^{V} \exp(u_l^{\top \]` where $V$ is the vocabulary size. -The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ +The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ is proportional to $V$, which can be easily in order of millions. To speed up training of Word2Vec, we used hierarchical softmax, which reduced the complexity of computing of $\log p(w_i | w_j)$ to $O(\log(V))$ @@ -278,8 +278,8 @@ for details on the API. multiplication. 
In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. -Qu8T948*1# -Denoting the `scalingVec` as "`w`," this transformation may be written as: + +Denoting the `scalingVec` as "`w`", this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index ca84551506b2..99cab98c690c 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -9,7 +9,7 @@ displayTitle: Regression - RDD-based API belongs to the family of regression algorithms. Formally isotonic regression is a problem where given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted -finding a function that minimises +finding a function that minimizes `\begin{equation} f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 @@ -28,7 +28,7 @@ best fitting the original data points. which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). The training input is an RDD of tuples of three double values that represent -label, feature and weight in this order. Additionally IsotonicRegression algorithm has one +label, feature and weight in this order. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is isotonic (monotonically increasing) or antitonic (monotonically decreasing). diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 034e89e25000..73f6e206ca54 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -425,7 +425,7 @@ We create our model by initializing the weights to zero and register the streams testing then start the job. Printing predictions alongside true labels lets us easily see the result. -Finally we can save text files with data to the training or testing folders. +Finally, we can save text files with data to the training or testing folders. Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label and `x1,x2,x3` are the features. Anytime a text file is placed in `args(0)` the model will update. Anytime a text file is placed in `args(1)` you will see predictions. diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 14d76a6e41e2..04758903da89 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -121,7 +121,7 @@ computation of the sum of the partial results from each worker machine is perfor standard spark routines. If the fraction of points `miniBatchFraction` is set to 1 (default), then the resulting step in -each iteration is exact (sub)gradient descent. In this case there is no randomness and no +each iteration is exact (sub)gradient descent. In this case, there is no randomness and no variance in the used step directions. On the other extreme, if `miniBatchFraction` is chosen very small, such that only a single point is sampled, i.e. 
`$|S|=$ miniBatchFraction $\cdot n = 1$`, then the algorithm is equivalent to @@ -135,7 +135,7 @@ algorithm in the family of quasi-Newton methods to solve the optimization proble quadratic without evaluating the second partial derivatives of the objective function to construct the Hessian matrix. The Hessian matrix is approximated by previous gradient evaluations, so there is no vertical scalability issue (the number of training features) when computing the Hessian matrix -explicitly in Newton's method. As a result, L-BFGS often achieves rapider convergence compared with +explicitly in Newton's method. As a result, L-BFGS often achieves more rapid convergence compared with other first-order optimization. ### Choosing an Optimization Method diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index d3530908706d..f56756543792 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -7,7 +7,7 @@ displayTitle: PMML model export - RDD-based API * Table of contents {:toc} -## `spark.mllib` supported models +## spark.mllib supported models `spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). @@ -15,7 +15,7 @@ The table below outlines the `spark.mllib` models that can be exported to PMML a - + diff --git a/docs/monitoring.md b/docs/monitoring.md index f8d3ce91a069..6eaf33135744 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -80,7 +80,10 @@ The history server can be configured as follows:
    `spark.mllib` modelPMML model
    spark.mllib modelPMML model
    -### Spark configuration options +### Spark History Server Configuration Options + +Security options for the Spark History Server are covered more detail in the +[Security](security.html#web-ui) page. @@ -118,7 +121,7 @@ The history server can be configured as follows: @@ -160,41 +163,6 @@ The history server can be configured as follows: Location of the kerberos keytab file for the History Server. - - - - - - - - - - - - - - - @@ -246,7 +214,7 @@ incomplete attempt or the final successful attempt. 2. Incomplete applications are only updated intermittently. The time between updates is defined by the interval between checks for changed files (`spark.history.fs.update.interval`). -On larger clusters the update interval may be set to large values. +On larger clusters, the update interval may be set to large values. The way to view a running application is actually to view its own web UI. 3. Applications which exited without registering themselves as completed will be listed @@ -347,6 +315,13 @@ can be identified by their `[attempt-id]`. In the API listed below, when running + + + + @@ -407,7 +382,7 @@ can be identified by their `[attempt-id]`. In the API listed below, when running
    Property NameDefaultMeaning
    The number of applications to retain UI data for in the cache. If this cap is exceeded, then the oldest applications will be removed from the cache. If an application is not in the cache, - it will have to be loaded from disk if its accessed from the UI. + it will have to be loaded from disk if it is accessed from the UI.
    spark.history.ui.acls.enablefalse - Specifies whether acls should be checked to authorize users viewing the applications. - If enabled, access control checks are made regardless of what the individual application had - set for spark.ui.acls.enable when the application was run. The application owner - will always have authorization to view their own application and any users specified via - spark.ui.view.acls and groups specified via spark.ui.view.acls.groups - when the application was run will also have authorization to view that application. - If disabled, no access control checks are made. -
    spark.history.ui.admin.aclsempty - Comma separated list of users/administrators that have view access to all the Spark applications in - history server. By default only the users permitted to view the application at run-time could - access the related application history, with this, configured users/administrators could also - have the permission to access it. - Putting a "*" in the list means any user can have the privilege of admin. -
    spark.history.ui.admin.acls.groupsempty - Comma separated list of groups that have view access to all the Spark applications in - history server. By default only the groups permitted to view the application at run-time could - access the related application history, with this, configured groups could also - have the permission to access it. - Putting a "*" in the list means any group can have the privilege of admin. -
    spark.history.fs.cleaner.enabled false/applications/[app-id]/executors A list of all active executors for the given application.
    /applications/[app-id]/executors/[executor-id]/threads + Stack traces of all the threads running within the given active executor. + Not available via the history server. +
    /applications/[app-id]/allexecutors A list of all(active and dead) executors for the given application.
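For example (a sketch only; the host, port and application id below are placeholders), the executor endpoints listed above can be queried like any other REST endpoint of a running application's UI:

{% highlight python %}
import json
from urllib.request import urlopen

# Placeholder UI address of a running application and a placeholder app id.
base_url = "http://localhost:4040/api/v1"
app_id = "app-20180101000000-0000"

# Fetch the list of active executors for the application.
with urlopen("%s/applications/%s/executors" % (base_url, app_id)) as resp:
    executors = json.load(resp)

print([executor["id"] for executor in executors])
{% endhighlight %}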
    -The number of jobs and stages which can retrieved is constrained by the same retention +The number of jobs and stages which can be retrieved is constrained by the same retention mechanism of the standalone Spark UI; `"spark.ui.retainedJobs"` defines the threshold value triggering garbage collection on jobs, and `spark.ui.retainedStages` that for stages. Note that the garbage collection takes place on playback: it is possible to retrieve @@ -422,10 +397,10 @@ These endpoints have been strongly versioned to make it easier to develop applic * Individual fields will never be removed for any given endpoint * New endpoints may be added * New fields may be added to existing endpoints -* New versions of the api may be added in the future at a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. +* New versions of the api may be added in the future as a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. * Api versions may be dropped, but only after at least one minor release of co-existing with a new api version. -Note that even when examining the UI of a running applications, the `applications/[app-id]` portion is +Note that even when examining the UI of running applications, the `applications/[app-id]` portion is still required, though there is only one application available. Eg. to see the list of jobs for the running app, you would go to `http://localhost:4040/api/v1/applications/[app-id]/jobs`. This is to keep the paths consistent in both modes. @@ -447,7 +422,7 @@ configuration property. If, say, users wanted to set the metrics namespace to the name of the application, they can set the `spark.metrics.namespace` property to a value like `${spark.app.name}`. This value is then expanded appropriately by Spark and is used as the root namespace of the metrics system. -Non driver and executor metrics are never prefixed with `spark.app.id`, nor does the +Non-driver and executor metrics are never prefixed with `spark.app.id`, nor does the `spark.metrics.namespace` property have any such affect on such metrics. Spark's metrics are decoupled into different diff --git a/docs/quick-start.md b/docs/quick-start.md index 200b97230e86..f1a2096cd4db 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -11,11 +11,11 @@ This tutorial provides a quick introduction to using Spark. We will first introd interactive shell (in Python or Scala), then show how to write applications in Java, Scala, and Python. -To follow along with this guide, first download a packaged release of Spark from the +To follow along with this guide, first, download a packaged release of Spark from the [Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS, you can download a package for any version of Hadoop. -Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more complete reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. +Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). 
After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more detailed reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. # Interactive Analysis with the Spark Shell @@ -47,7 +47,7 @@ scala> textFile.first() // First item in this Dataset res1: String = # Apache Spark {% endhighlight %} -Now let's transform this Dataset to a new one. We call `filter` to return a new Dataset with a subset of the items in the file. +Now let's transform this Dataset into a new one. We call `filter` to return a new Dataset with a subset of the items in the file. {% highlight scala %} scala> val linesWithSpark = textFile.filter(line => line.contains("Spark")) @@ -67,7 +67,7 @@ res3: Long = 15 ./bin/pyspark -Or if PySpark is installed with pip in your current enviroment: +Or if PySpark is installed with pip in your current environment: pyspark @@ -156,7 +156,7 @@ One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can i >>> wordCounts = textFile.select(explode(split(textFile.value, "\s+")).alias("word")).groupBy("word").count() {% endhighlight %} -Here, we use the `explode` function in `select`, to transfrom a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`: +Here, we use the `explode` function in `select`, to transform a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`: {% highlight python %} >>> wordCounts.collect() @@ -422,7 +422,7 @@ $ YOUR_SPARK_HOME/bin/spark-submit \ Lines with a: 46, Lines with b: 23 {% endhighlight %} -If you have PySpark pip installed into your enviroment (e.g., `pip install pyspark`), you can run your application with the regular Python interpreter or use the provided 'spark-submit' as you prefer. +If you have PySpark pip installed into your environment (e.g., `pip install pyspark`), you can run your application with the regular Python interpreter or use the provided 'spark-submit' as you prefer. {% highlight bash %} # Use the Python interpreter to run your application diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 29af159510e4..b6424090d2fe 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -91,7 +91,7 @@ so C libraries like NumPy can be used. It also works with PyPy 2.3+. Python 2.6 support was removed in Spark 2.2.0. -Spark applications in Python can either be run with the `bin/spark-submit` script which includes Spark at runtime, or by including including it in your setup.py as: +Spark applications in Python can either be run with the `bin/spark-submit` script which includes Spark at runtime, or by including it in your setup.py as: {% highlight python %} install_requires=[ @@ -818,7 +818,7 @@ The behavior of the above code is undefined, and may not work as intended. 
To ex The variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. -In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. +In local mode, in some circumstances, the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e491329136a3..e9e1f3e28060 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -8,14 +8,21 @@ title: Running Spark on Kubernetes Spark can run on clusters managed by [Kubernetes](https://kubernetes.io). This feature makes use of native Kubernetes scheduler that has been added to Spark. +**The Kubernetes scheduler is currently experimental. +In future versions, there may be behavioral changes around configuration, +container images and entrypoints.** + # Prerequisites * A runnable distribution of Spark 2.3 or above. * A running Kubernetes cluster at version >= 1.6 with access configured to it using [kubectl](https://kubernetes.io/docs/user-guide/prereqs/). If you do not already have a working Kubernetes cluster, -you may setup a test cluster on your local machine using +you may set up a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). * We recommend using the latest release of minikube with the DNS addon enabled. + * Be aware that the default minikube configuration is not enough for running Spark applications. + We recommend 3 CPUs and 4g of memory to be able to start a simple Spark application with a single + executor. * You must have appropriate permissions to list, create, edit and delete [pods](https://kubernetes.io/docs/user-guide/pods/) in your cluster. You can verify that you can list these resources by running `kubectl auth can-i pods`. @@ -38,11 +45,10 @@ logs and remains in "completed" state in the Kubernetes API until it's eventuall Note that in the completed state, the driver pod does *not* use any computational or memory resources. -The driver and executor pod scheduling is handled by Kubernetes. It will be possible to affect Kubernetes scheduling -decisions for driver and executor pods using advanced primitives like -[node selectors](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector) -and [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity) -in a future release. +The driver and executor pod scheduling is handled by Kubernetes. 
It is possible to schedule the +driver and executor pods on a subset of available nodes through a [node selector](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector) +using the configuration property for it. It will be possible to use more advanced +scheduling hints like [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity) in a future release. # Submitting Applications to Kubernetes @@ -50,20 +56,19 @@ in a future release. Kubernetes requires users to supply images that can be deployed into containers within pods. The images are built to be run in a container runtime environment that Kubernetes supports. Docker is a container runtime environment that is -frequently used with Kubernetes. With Spark 2.3, there are Dockerfiles provided in the runnable distribution that can be customized -and built for your usage. +frequently used with Kubernetes. Spark (starting with version 2.3) ships with a Dockerfile that can be used for this +purpose, or customized to match an individual application's needs. It can be found in the `kubernetes/dockerfiles/` +directory. -You may build these docker images from sources. -There is a script, `sbin/build-push-docker-images.sh` that you can use to build and push -customized Spark distribution images consisting of all the above components. +Spark also ships with a `bin/docker-image-tool.sh` script that can be used to build and publish the Docker images to +use with the Kubernetes backend. Example usage is: - ./sbin/build-push-docker-images.sh -r -t my-tag build - ./sbin/build-push-docker-images.sh -r -t my-tag push - -Docker files are under the `kubernetes/dockerfiles/` directory and can be customized further before -building using the supplied script, or manually. +```bash +$ ./bin/docker-image-tool.sh -r -t my-tag build +$ ./bin/docker-image-tool.sh -r -t my-tag push +``` ## Cluster Mode @@ -76,8 +81,7 @@ $ bin/spark-submit \ --name spark-pi \ --class org.apache.spark.examples.SparkPi \ --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.container.image= \ - --conf spark.kubernetes.executor.container.image= \ + --conf spark.kubernetes.container.image= \ local:///path/to/examples.jar ``` @@ -95,7 +99,7 @@ must consist of lower case alphanumeric characters, `-`, and `.` and must start If you have a Kubernetes cluster setup, one way to discover the apiserver URL is by executing `kubectl cluster-info`. ```bash -kubectl cluster-info +$ kubectl cluster-info Kubernetes master is running at http://127.0.0.1:6443 ``` @@ -106,7 +110,7 @@ authenticating proxy, `kubectl proxy` to communicate to the Kubernetes API. The local proxy can be started by: ```bash -kubectl proxy +$ kubectl proxy ``` If the local proxy is running at localhost:8001, `--master k8s://http://127.0.0.1:8001` can be used as the argument to @@ -118,37 +122,9 @@ This URI is the location of the example jar that is already in the Docker image. If your application's dependencies are all hosted in remote locations like HDFS or HTTP servers, they may be referred to by their appropriate remote URIs. Also, application dependencies can be pre-mounted into custom-built Docker images. Those dependencies can be added to the classpath by referencing them with `local://` URIs and/or setting the -`SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. 
- -### Using Remote Dependencies -When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods -need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading -the dependencies so the driver and executor containers can use them locally. This requires users to specify the container -image for the init-container using the configuration property `spark.kubernetes.initContainer.image`. For example, users -simply add the following option to the `spark-submit` command to specify the init-container image: - -``` ---conf spark.kubernetes.initContainer.image= -``` - -The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and -`spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g., -the main application jar. The following shows an example of using remote dependencies with the `spark-submit` command: - -```bash -$ bin/spark-submit \ - --master k8s://https://: \ - --deploy-mode cluster \ - --name spark-pi \ - --class org.apache.spark.examples.SparkPi \ - --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar - --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2 - --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.container.image= \ - --conf spark.kubernetes.executor.container.image= \ - --conf spark.kubernetes.initContainer.image= - https://path/to/examples.jar -``` +`SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. The `local://` scheme is also required when referring to +dependencies in custom-built Docker images in `spark-submit`. Note that using application dependencies from the submission +client's local file system is currently not yet supported. ## Secret Management Kubernetes [Secrets](https://kubernetes.io/docs/concepts/configuration/secret/) can be used to provide credentials for a @@ -164,10 +140,6 @@ namespace as that of the driver and executor pods. For example, to mount a secre --conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets ``` -Note that if an init-container is used, any secret mounted into the driver container will also be mounted into the -init-container of the driver. Similarly, any secret mounted into an executor container will also be mounted into the -init-container of the executor. - ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -179,7 +151,7 @@ Logs can be accessed using the Kubernetes API and the `kubectl` CLI. When a Spar to stream logs from the application using: ```bash -kubectl -n= logs -f +$ kubectl -n= logs -f ``` The same logs can also be accessed through the @@ -192,12 +164,12 @@ The UI associated with any application can be accessed locally using [`kubectl port-forward`](https://kubernetes.io/docs/tasks/access-application-cluster/port-forward-access-application-cluster/#forward-a-local-port-to-a-port-on-the-pod). ```bash -kubectl port-forward 4040:4040 +$ kubectl port-forward 4040:4040 ``` Then, the Spark driver UI can be accessed on `http://localhost:4040`. -### Debugging +### Debugging There may be several kinds of failures. If the Kubernetes API server rejects the request made from spark-submit, or the connection is refused for a different reason, the submission logic should indicate the error encountered. 
However, if there @@ -206,17 +178,17 @@ are errors during the running of the application, often, the best way to investi To get some basic information about the scheduling decisions made around the driver pod, you can run: ```bash -kubectl describe pod +$ kubectl describe pod ``` If the pod has encountered a runtime error, the status can be probed further using: ```bash -kubectl logs +$ kubectl logs ``` -Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark -application, includling all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of +Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark +application, including all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of the Spark application. ## Kubernetes Features @@ -249,7 +221,7 @@ that allows driver pods to create pods and services under the default Kubernetes [RBAC](https://kubernetes.io/docs/admin/authorization/rbac/) policies. Sometimes users may need to specify a custom service account that has the right role granted. Spark on Kubernetes supports specifying a custom service account to be used by the driver pod through the configuration property -`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example to make the driver pod +`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example, to make the driver pod use the `spark` service account, a user simply adds the following option to the `spark-submit` command: ``` @@ -260,7 +232,7 @@ To create a custom service account, a user can use the `kubectl create serviceac following command creates a service account named `spark`: ```bash -kubectl create serviceaccount spark +$ kubectl create serviceaccount spark ``` To grant a service account a `Role` or `ClusterRole`, a `RoleBinding` or `ClusterRoleBinding` is needed. To create @@ -269,7 +241,7 @@ for `ClusterRoleBinding`) command. For example, the following command creates an namespace and grants it to the `spark` service account created above: ```bash -kubectl create clusterrolebinding spark-role --clusterrole=edit --serviceaccount=default:spark --namespace=default +$ kubectl create clusterrolebinding spark-role --clusterrole=edit --serviceaccount=default:spark --namespace=default ``` Note that a `Role` can only be used to grant access to resources (like pods) within a single namespace, whereas a @@ -319,21 +291,27 @@ specific to Spark on Kubernetes. - spark.kubernetes.driver.container.image + spark.kubernetes.container.image (none) - Container image to use for the driver. - This is usually of the form example.com/repo/spark-driver:v1.0.0. - This configuration is required and must be provided by the user. + Container image to use for the Spark application. + This is usually of the form example.com/repo/spark:v1.0.0. + This configuration is required and must be provided by the user, unless explicit + images are provided for each different container type. + + + + spark.kubernetes.driver.container.image + (value of spark.kubernetes.container.image) + + Custom container image to use for the driver. spark.kubernetes.executor.container.image - (none) + (value of spark.kubernetes.container.image) - Container image to use for the executors. - This is usually of the form example.com/repo/spark-executor:v1.0.0. 
- This configuration is required and must be provided by the user. + Custom container image to use for executors. @@ -543,14 +521,6 @@ specific to Spark on Kubernetes. to avoid name conflicts. - - spark.kubernetes.executor.podNamePrefix - (none) - - Prefix for naming the executor pods. - If not set, the executor pod name is set to driver pod name suffixed by an integer. - - spark.kubernetes.executor.lostCheck.maxAttempts 10 @@ -579,14 +549,23 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + + spark.kubernetes.executor.request.cores + (none) + + Specify the cpu request for each executor pod. Values conform to the Kubernetes [convention](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#meaning-of-cpu). + Example values include 0.1, 500m, 1.5, 5, etc., with the definition of cpu units documented in [CPU units](https://kubernetes.io/docs/tasks/configure-pod-container/assign-cpu-resource/#cpu-units). + This is distinct from spark.executor.cores: it is only used and takes precedence over spark.executor.cores for specifying the executor pod cpu request if set. Task + parallelism, e.g., number of tasks an executor can run concurrently is not affected by this. + spark.kubernetes.executor.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. @@ -607,51 +586,12 @@ specific to Spark on Kubernetes. the Driver process. The user can specify multiple of these to set multiple environment variables. - - spark.kubernetes.mountDependencies.jarsDownloadDir - /var/spark-data/spark-jars - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.filesDownloadDir - /var/spark-data/spark-files - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.timeout - 300s - - Timeout in seconds before aborting the attempt to download and unpack dependencies from remote locations into - the driver and executor pods. - - - - spark.kubernetes.mountDependencies.maxSimultaneousDownloads - 5 - - Maximum number of remote dependencies to download simultaneously in a driver or executor pod. - - - - spark.kubernetes.initContainer.image - (none) - - Container image for the init-container of the driver and executors for downloading dependencies. This is usually of the form example.com/repo/spark-init:v1.0.0. This configuration is optional and must be provided by the user if any non-container local dependency is used and must be downloaded remotely. 
- - spark.kubernetes.driver.secrets.[SecretName] (none) Add the Kubernetes Secret named SecretName to the driver pod on the path specified in the value. For example, - spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, - the secret will also be added to the init-container in the driver pod. + spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. @@ -659,8 +599,7 @@ specific to Spark on Kubernetes. (none) Add the Kubernetes Secret named SecretName to the executor pod on the path specified in the value. For example, - spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, - the secret will also be added to the init-container in the executor pod. + spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. - \ No newline at end of file + diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 382cbfd5301b..3c2a1501ca69 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -82,6 +82,27 @@ a Spark driver program configured to connect to Mesos. Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure `spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. +## Authenticating to Mesos + +When Mesos Framework authentication is enabled it is necessary to provide a principal and secret by which to authenticate Spark to Mesos. Each Spark job will register with Mesos as a separate framework. + +Depending on your deployment environment you may wish to create a single set of framework credentials that are shared across all users or create framework credentials for each user. Creating and managing framework credentials should be done following the Mesos [Authentication documentation](http://mesos.apache.org/documentation/latest/authentication/). + +Framework credentials may be specified in a variety of ways depending on your deployment environment and security requirements. The most simple way is to specify the `spark.mesos.principal` and `spark.mesos.secret` values directly in your Spark configuration. Alternatively you may specify these values indirectly by instead specifying `spark.mesos.principal.file` and `spark.mesos.secret.file`, these settings point to files containing the principal and secret. These files must be plaintext files in UTF-8 encoding. Combined with appropriate file ownership and mode/ACLs this provides a more secure way to specify these credentials. + +Additionally, if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. + +### Credential Specification Preference Order + +Please note that if you specify multiple ways to obtain the credentials then the following preference order applies. Spark will use the first valid value found and any subsequent values are ignored: + +- `spark.mesos.principal` configuration setting +- `SPARK_MESOS_PRINCIPAL` environment variable +- `spark.mesos.principal.file` configuration setting +- `SPARK_MESOS_PRINCIPAL_FILE` environment variable + +An equivalent order applies for the secret. Essentially we prefer the configuration to be specified directly rather than indirectly by files, and we prefer that configuration settings are used over environment variables. 
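To make the file-based option above concrete, here is a minimal sketch of supplying the credentials from application code; the Mesos master URL and file paths are placeholders, and the same keys can equally be passed with `--conf` on `spark-submit`:

{% highlight python %}
from pyspark import SparkConf
from pyspark.sql import SparkSession

# Placeholder Mesos master URL and credential file locations.
conf = (SparkConf()
        .setMaster("mesos://host:5050")
        .setAppName("mesos-auth-sketch")
        .set("spark.mesos.principal.file", "/etc/spark/mesos.principal")
        .set("spark.mesos.secret.file", "/etc/spark/mesos.secret"))

# Spark reads the principal and secret from these files when registering with
# Mesos as a framework (subject to the precedence order described above).
spark = SparkSession.builder.config(conf=conf).getOrCreate()
{% endhighlight %}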
+ ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary @@ -154,7 +175,7 @@ can find the results of the driver from the Mesos Web UI. To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispacther, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. +By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispatcher, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. @@ -204,7 +225,7 @@ details and default values. Executors are brought up eagerly when the application starts, until `spark.cores.max` is reached. If you don't set `spark.cores.max`, the Spark application will consume all resources offered to it by Mesos, -so we of course urge you to set this variable in any sort of +so we, of course, urge you to set this variable in any sort of multi-tenant cluster, including one which runs multiple concurrent Spark applications. @@ -212,14 +233,14 @@ The scheduler will start executors round-robin on the offers Mesos gives it, but there are no spread guarantees, as Mesos does not provide such guarantees on the offer stream. -In this mode spark executors will honor port allocation if such is -provided from the user. Specifically if the user defines +In this mode Spark executors will honor port allocation if such is +provided from the user. Specifically, if the user defines `spark.blockManager.port` in Spark configuration, the mesos scheduler will check the available offers for a valid port range containing the port numbers. If no such range is available it will not launch any task. If no restriction is imposed on port numbers by the user, ephemeral ports are used as usual. This port honouring implementation -implies one task per host if the user defines a port. In the future network +implies one task per host if the user defines a port. In the future network, isolation shall be supported. The benefit of coarse-grained mode is much lower startup overhead, but @@ -427,7 +448,14 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.principal (none) - Set the principal with which Spark framework will use to authenticate with Mesos. + Set the principal with which Spark framework will use to authenticate with Mesos. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL`. + + + + spark.mesos.principal.file + (none) + + Set the file containing the principal with which Spark framework will use to authenticate with Mesos. Allows specifying the principal indirectly in more security conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL_FILE`. 
@@ -435,7 +463,15 @@ See the [configuration page](configuration.html) for information on Spark config (none) Set the secret with which the Spark framework will authenticate with Mesos. Used, for example, when - authenticating with the registry. + authenticating with the registry. You can also specify this via the environment variable `SPARK_MESOS_SECRET`. + + + + spark.mesos.secret.file + (none) + + Set the file containing the secret with which the Spark framework will authenticate with Mesos. Used, for example, when + authenticating with the registry. Allows for specifying the secret indirectly in more security-conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_SECRET_FILE`. @@ -450,7 +486,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.constraints (none) - Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting + Attribute-based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting applies only to executors. Refer to Mesos Attributes & Resources for more information on attributes.
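For illustration, constraints are expressed as `attribute:value` pairs separated by semicolons and passed like any other configuration setting. The attribute names below are hypothetical and must match attributes actually advertised by your Mesos agents:

```bash
# Illustrative only: accept offers only from agents that advertise os:centos7 and rack:rack-1.
./bin/spark-submit \
  --master mesos://mesos-master.example.com:5050 \
  --conf spark.mesos.constraints="os:centos7;rack:rack-1" \
  --class org.apache.spark.examples.SparkPi \
  examples/jars/spark-examples*.jar 100
```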
      diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e7edec599036..ceda8a3ae240 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -2,6 +2,8 @@ layout: global title: Running Spark on YARN --- +* This will become a table of contents (this text will be scraped). +{:toc} Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) @@ -35,7 +37,7 @@ For example: --executor-memory 2g \ --executor-cores 1 \ --queue thequeue \ - lib/spark-examples*.jar \ + examples/jars/spark-examples*.jar \ 10 The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. @@ -217,8 +219,8 @@ To use a custom metrics.properties for the application master and executors, upd spark.yarn.dist.forceDownloadSchemes (none) - Comma-separated list of schemes for which files will be downloaded to the local disk prior to - being added to YARN's distributed cache. For use in cases where the YARN service does not + Comma-separated list of schemes for which files will be downloaded to the local disk prior to + being added to YARN's distributed cache. For use in cases where the YARN service does not support schemes that are supported by Spark, like http, https and ftp. @@ -265,19 +267,6 @@ To use a custom metrics.properties for the application master and executors, upd distribution. - - spark.yarn.access.hadoopFileSystems - (none) - - A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For - example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032, - webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed - and Kerberos must be properly configured to be able to access them (either in the same realm - or in a trusted realm). Spark acquires security tokens for each of the filesystems so that - the Spark application can access those remote Hadoop filesystems. spark.yarn.access.namenodes - is deprecated, please use this instead. - - spark.yarn.appMasterEnv.[EnvironmentVariableName] (none) @@ -373,31 +362,6 @@ To use a custom metrics.properties for the application master and executors, upd in YARN ApplicationReports, which can be used for filtering when querying YARN apps. - - spark.yarn.keytab - (none) - - The full path to the file that contains the keytab for the principal specified above. - This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, - for renewing the login tickets and the delegation tokens periodically. (Works also with the "local" master) - - - - spark.yarn.principal - (none) - - Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) - - - - spark.yarn.kerberos.relogin.period - 1m - - How often to check whether the kerberos TGT should be renewed. This should be set to a value - that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). - The default value should be enough for most deployments. 
- - spark.yarn.config.gatewayPath (none) @@ -424,17 +388,6 @@ To use a custom metrics.properties for the application master and executors, upd See spark.yarn.config.gatewayPath. - - spark.security.credentials.${service}.enabled - true - - Controls whether to obtain credentials for services when security is enabled. - By default, credentials for all supported services are retrieved when those services are - configured, but it's possible to disable that behavior if it somehow conflicts with the - application being run. For further details please see - [Running in a Secure Cluster](running-on-yarn.html#running-in-a-secure-cluster) - - spark.yarn.rolledLog.includePattern (none) @@ -445,7 +398,7 @@ To use a custom metrics.properties for the application master and executors, upd yarn.nodemanager.log-aggregation.roll-monitoring-interval-seconds should be configured in yarn-site.xml. This feature can only be used with Hadoop 2.6.4+. The Spark log4j appender needs be changed to use - FileAppender or another appender that can handle the files being removed while its running. Based + FileAppender or another appender that can handle the files being removed while it is running. Based on the file name configured in the log4j configuration (like spark.log), the user should set the regex (spark*) to include all the log files that need to be aggregated. @@ -465,51 +418,107 @@ To use a custom metrics.properties for the application master and executors, upd - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - In `cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `client` mode, only the Spark executors do. -- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. +- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example, you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. -# Running in a Secure Cluster +# Kerberos + +Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page. + +In YARN mode, when accessing Hadoop file systems, aside from the service hosting the user's home +directory, Spark will also automatically obtain delegation tokens for the service hosting the +staging directory of the Spark application. 
+ +If an application needs to interact with other secure Hadoop filesystems, their URIs need to be +explicitly provided to Spark at launch time. This is done by listing them in the +`spark.yarn.access.hadoopFileSystems` property, described in the configuration section below. -As covered in [security](security.html), Kerberos is used in a secure Hadoop cluster to -authenticate principals associated with services and clients. This allows clients to -make requests of these authenticated services; the services to grant rights -to the authenticated principals. +The YARN integration also supports custom delegation token providers using the Java Services +mechanism (see `java.util.ServiceLoader`). Implementations of +`org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` can be made available to Spark +by listing their names in the corresponding file in the jar's `META-INF/services` directory. These +providers can be disabled individually by setting `spark.security.credentials.{service}.enabled` to +`false`, where `{service}` is the name of the credential provider. + +## YARN-specific Kerberos Configuration + + + + + + + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.yarn.keytab(none) + The full path to the file that contains the keytab for the principal specified above. This keytab + will be copied to the node running the YARN Application Master via the YARN Distributed Cache, and + will be used for renewing the login tickets and the delegation tokens periodically. Equivalent to + the --keytab command line argument. + +
      (Works also with the "local" master.) +
      spark.yarn.principal(none) + Principal to be used to login to KDC, while running on secure clusters. Equivalent to the + --principal command line argument. + +
      (Works also with the "local" master.) +
      spark.yarn.access.hadoopFileSystems(none) + A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For + example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032, + webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed + and Kerberos must be properly configured to be able to access them (either in the same realm + or in a trusted realm). Spark acquires security tokens for each of the filesystems so that + the Spark application can access those remote Hadoop filesystems. +
      spark.yarn.kerberos.relogin.period1m + How often to check whether the kerberos TGT should be renewed. This should be set to a value + that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). + The default value should be enough for most deployments. +
      -Hadoop services issue *hadoop tokens* to grant access to the services and data. -Clients must first acquire tokens for the services they will access and pass them along with their -application as it is launched in the YARN cluster. +## Troubleshooting Kerberos -For a Spark application to interact with any of the Hadoop filesystem (for example hdfs, webhdfs, etc), HBase and Hive, it must acquire the relevant tokens -using the Kerberos credentials of the user launching the application -—that is, the principal whose identity will become that of the launched Spark application. +Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to +enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG` +environment variable. -This is normally done at launch time: in a secure cluster Spark will automatically obtain a -token for the cluster's default Hadoop filesystem, and potentially for HBase and Hive. +```bash +export HADOOP_JAAS_DEBUG=true +``` -An HBase token will be obtained if HBase is in on classpath, the HBase configuration declares -the application is secure (i.e. `hbase-site.xml` sets `hbase.security.authentication` to `kerberos`), -and `spark.security.credentials.hbase.enabled` is not set to `false`. +The JDK classes can be configured to enable extra logging of their Kerberos and +SPNEGO/REST authentication via the system properties `sun.security.krb5.debug` +and `sun.security.spnego.debug=true` -Similarly, a Hive token will be obtained if Hive is on the classpath, its configuration -includes a URI of the metadata store in `"hive.metastore.uris`, and -`spark.security.credentials.hive.enabled` is not set to `false`. +``` +-Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true +``` -If an application needs to interact with other secure Hadoop filesystems, then -the tokens needed to access these clusters must be explicitly requested at -launch time. This is done by listing them in the `spark.yarn.access.hadoopFileSystems` property. +All these options can be enabled in the Application Master: ``` -spark.yarn.access.hadoopFileSystems hdfs://ireland.example.org:8020/,webhdfs://frankfurt.example.org:50070/ +spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true +spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true ``` -Spark supports integrating with other security-aware services through Java Services mechanism (see -`java.util.ServiceLoader`). To do that, implementations of `org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` -should be available to Spark by listing their names in the corresponding file in the jar's -`META-INF/services` directory. These plug-ins can be disabled by setting -`spark.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of -credential provider. +Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log +will include a list of all tokens obtained, and their expiry details -## Configuring the External Shuffle Service + +# Configuring the External Shuffle Service To start the Spark Shuffle Service on each `NodeManager` in your YARN cluster, follow these instructions: @@ -542,7 +551,7 @@ The following extra configuration options are available when the shuffle service -## Launching your application with Apache Oozie +# Launching your application with Apache Oozie Apache Oozie can launch Spark applications as part of a workflow. 
In a secure cluster, the launched application will need the relevant tokens to access the cluster's @@ -576,35 +585,7 @@ spark.security.credentials.hbase.enabled false The configuration option `spark.yarn.access.hadoopFileSystems` must be unset. -## Troubleshooting Kerberos - -Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to -enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG` -environment variable. - -```bash -export HADOOP_JAAS_DEBUG=true -``` - -The JDK classes can be configured to enable extra logging of their Kerberos and -SPNEGO/REST authentication via the system properties `sun.security.krb5.debug` -and `sun.security.spnego.debug=true` - -``` --Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true -``` - -All these options can be enabled in the Application Master: - -``` -spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true -spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true -``` - -Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log -will include a list of all tokens obtained, and their expiry details - -## Using the Spark History Server to replace the Spark Web UI +# Using the Spark History Server to replace the Spark Web UI It is possible to use the Spark History Server application page as the tracking URL for running applications when the application UI is disabled. This may be desirable on secure clusters, or to diff --git a/docs/security.md b/docs/security.md index 15aadf07cf87..8c0c66fb5a28 100644 --- a/docs/security.md +++ b/docs/security.md @@ -3,41 +3,323 @@ layout: global displayTitle: Spark Security title: Security --- +* This will become a table of contents (this text will be scraped). +{:toc} -Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: +# Spark RPC -* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. -* For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. +## Authentication -## Web UI +Spark currently supports authentication for RPC channels using a shared secret. Authentication can +be turned on by setting the `spark.authenticate` configuration parameter. -The Spark UI can be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting -and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via [SSL settings](security.html#ssl-configuration). +The exact mechanism used to generate and distribute the shared secret is deployment-specific. -### Authentication +For Spark on [YARN](running-on-yarn.html) and local deployments, Spark will automatically handle +generating and distributing the shared secret. Each application will use a unique shared secret. 
In +the case of YARN, this feature relies on YARN RPC encryption being enabled for the distribution of +secrets to be secure. -A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable`, `spark.ui.view.acls` and `spark.ui.view.acls.groups` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. +For other resource managers, `spark.authenticate.secret` must be configured on each of the nodes. +This secret will be shared by all the daemons and applications, so this deployment configuration is +not as secure as the above, especially when considering multi-tenant clusters. -Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable`, `spark.modify.acls` and `spark.modify.acls.groups`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. -Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the configs `spark.admin.acls` and `spark.admin.acls.groups`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.authenticatefalseWhether Spark authenticates its internal connections.
spark.authenticate.secretNone + The secret key used for authentication. See above for when this configuration should be set. +
      + +## Encryption -## Event Logging +Spark supports AES-based encryption for RPC connections. For encryption to be enabled, RPC +authentication must also be enabled and properly configured. AES encryption uses the +[Apache Commons Crypto](http://commons.apache.org/proper/commons-crypto/) library, and Spark's +configuration system allows access to that library's configuration for advanced users. -If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access. +There is also support for SASL-based encryption, although it should be considered deprecated. It +is still required when talking to shuffle services from Spark versions older than 2.2.0. -## Encryption +The following table describes the different options available for configuring this feature. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.network.crypto.enabledfalse + Enable AES-based RPC encryption, including the new authentication protocol added in 2.2.0. +
      spark.network.crypto.keyLength128 + The length in bits of the encryption key to generate. Valid values are 128, 192 and 256. +
      spark.network.crypto.keyFactoryAlgorithmPBKDF2WithHmacSHA1 + The key factory algorithm to use when generating encryption keys. Should be one of the + algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used. +
      spark.network.crypto.config.*None + Configuration values for the commons-crypto library, such as which cipher implementations to + use. The config name should be the name of commons-crypto configuration without the + commons.crypto prefix. +
      spark.network.crypto.saslFallbacktrue + Whether to fall back to SASL authentication if authentication fails using Spark's internal + mechanism. This is useful when the application is connecting to old shuffle services that + do not support the internal Spark authentication protocol. On the shuffle service side, + disabling this feature will block older clients from authenticating. +
      spark.authenticate.enableSaslEncryptionfalse + Enable SASL-based encrypted communication. +
      spark.network.sasl.serverAlwaysEncryptfalse + Disable unencrypted connections for ports using SASL authentication. This will deny connections + from clients that have authentication enabled, but do not request SASL-based encryption. +
      + + +# Local Storage Encryption + +Spark supports encrypting temporary data written to local disks. This covers shuffle files, shuffle +spills and data blocks stored on disk (for both caching and broadcast variables). It does not cover +encrypting output data generated by applications with APIs such as `saveAsHadoopFile` or +`saveAsTable`. + +The following settings cover enabling encryption for data written to disk: + + + + + + + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.io.encryption.enabledfalse + Enable local disk I/O encryption. Currently supported by all modes except Mesos. It's strongly + recommended that RPC encryption be enabled when using this feature. +
      spark.io.encryption.keySizeBits128 + IO encryption key size in bits. Supported values are 128, 192 and 256. +
      spark.io.encryption.keygen.algorithmHmacSHA1 + The algorithm to use when generating the IO encryption key. The supported algorithms are + described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm + Name Documentation. +
      spark.io.encryption.commons.config.*None + Configuration values for the commons-crypto library, such as which cipher implementations to + use. The config name should be the name of commons-crypto configuration without the + commons.crypto prefix. +
      + + +# Web UI + +## Authentication and Authorization + +Enabling authentication for the Web UIs is done using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html). +You will need a filter that implements the authentication method you want to deploy. Spark does not +provide any built-in authentication filters. + +Spark also supports access control to the UI when an authentication filter is present. Each +application can be configured with its own separate access control lists (ACLs). Spark +differentiates between "view" permissions (who is allowed to see the application's UI), and "modify" +permissions (who can do things like kill jobs in a running application). + +ACLs can be configured for either users or groups. Configuration entries accept comma-separated +lists as input, meaning multiple users or groups can be given the desired privileges. This can be +used if you run on a shared cluster and have a set of administrators or developers who need to +monitor applications they may not have started themselves. A wildcard (`*`) added to specific ACL +means that all users will have the respective pivilege. By default, only the user submitting the +application is added to the ACLs. + +Group membership is established by using a configurable group mapping provider. The mapper is +configured using the spark.user.groups.mapping config option, described in the table +below. + +The following options control the authentication of Web UIs: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.ui.filtersNone + See the Spark UI configuration for how to configure + filters. +
      spark.acls.enablefalse + Whether UI ACLs should be enabled. If enabled, this checks to see if the user has access + permissions to view or modify the application. Note this requires the user to be authenticated, + so if no authentication filter is installed, this option does not do anything. +
      spark.admin.aclsNone + Comma-separated list of users that have view and modify access to the Spark application. +
      spark.admin.acls.groupsNone + Comma-separated list of groups that have view and modify access to the Spark application. +
      spark.modify.aclsNone + Comma-separated list of users that have modify access to the Spark application. +
      spark.modify.acls.groupsNone + Comma-separated list of groups that have modify access to the Spark application. +
      spark.ui.view.aclsNone + Comma-separated list of users that have view access to the Spark application. +
      spark.ui.view.acls.groupsNone + Comma-separated list of groups that have view access to the Spark application. +
      spark.user.groups.mappingorg.apache.spark.security.ShellBasedGroupsMappingProvider + The list of groups for a user is determined by a group mapping service defined by the trait + org.apache.spark.security.GroupMappingServiceProvider, which can be configured by + this property. + +
      By default, a Unix shell-based implementation is used, which collects this information + from the host OS. + +
      Note: This implementation supports only Unix/Linux-based environments. + Windows environment is currently not supported. However, a new platform/protocol can + be supported by implementing the trait mentioned above. +
      + +On YARN, the view and modify ACLs are provided to the YARN service when submitting applications, and +control who has the respective privileges via YARN interfaces. -Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service -and the RPC endpoints. Shuffle files can also be encrypted if desired. +## Spark History Server ACLs -### SSL Configuration +Authentication for the SHS Web UI is enabled the same way as for regular applications, using +servlet filters. + +To enable authorization in the SHS, a few extra options are used: + + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.history.ui.acls.enablefalse + Specifies whether ACLs should be checked to authorize users viewing the applications in + the history server. If enabled, access control checks are performed regardless of what the + individual applications had set for spark.ui.acls.enable. The application owner + will always have authorization to view their own application and any users specified via + spark.ui.view.acls and groups specified via spark.ui.view.acls.groups + when the application was run will also have authorization to view that application. + If disabled, no access control checks are made for any application UIs available through + the history server. +
      spark.history.ui.admin.aclsNone + Comma separated list of users that have view access to all the Spark applications in history + server. +
      spark.history.ui.admin.acls.groupsNone + Comma separated list of groups that have view access to all the Spark applications in history + server. +
+ +The SHS uses the same options to configure the group mapping provider as regular applications. +In this case, the group mapping provider will apply to all UIs served by the SHS, and individual +application configurations will be ignored. + +## SSL Configuration Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the -protocols without disabling the ability to configure each one individually. The common SSL settings -are at `spark.ssl` namespace in Spark configuration. The following table describes the -component-specific configuration namespaces used to override the default settings: +protocols without disabling the ability to configure each one individually. The following table +describes the SSL configuration namespaces:
      Component
      spark.ssl.fsFile download client (used to download jars and files from HTTPS-enabled servers).spark.ssl + The default SSL configuration. These values will apply to all namespaces below, unless + explicitly overridden at the namespace level. +
      spark.ssl.ui
      -The full breakdown of available SSL options can be found on the [configuration page](configuration.html). -SSL must be configured on each node and configured for each component involved in communication using the particular protocol. +The full breakdown of available SSL options can be found below. The `${ns}` placeholder should be +replaced with one of the above namespaces. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      ${ns}.enabledfalseEnables SSL. When enabled, ${ns}.ssl.protocol is required.
      ${ns}.portNone + The port where the SSL service will listen on. + +
      The port must be defined within a specific namespace configuration. The default + namespace is ignored when reading this configuration. + +
      When not set, the SSL port will be derived from the non-SSL port for the + same service. A value of "0" will make the service bind to an ephemeral port. +
      ${ns}.enabledAlgorithmsNone + A comma-separated list of ciphers. The specified ciphers must be supported by JVM. + +
      The reference list of protocols can be found in the "JSSE Cipher Suite Names" section + of the Java security guide. The list for Java 8 can be found at + this + page. + +
      Note: If not set, the default cipher suite for the JRE will be used. +
      ${ns}.keyPasswordNone + The password to the private key in the key store. +
      ${ns}.keyStoreNone + Path to the key store file. The path can be absolute or relative to the directory in which the + process is started. +
      ${ns}.keyStorePasswordNonePassword to the key store.
      ${ns}.keyStoreTypeJKSThe type of the key store.
      ${ns}.protocolNone + TLS protocol to use. The protocol must be supported by JVM. + +
      The reference list of protocols can be found in the "Additional JSSE Standard Names" + section of the Java security guide. For Java 8, the list can be found at + this + page. +
      ${ns}.needClientAuthfalseWhether to require client authentication.
      ${ns}.trustStoreNone + Path to the trust store file. The path can be absolute or relative to the directory in which + the process is started. +
      ${ns}.trustStorePasswordNonePassword for the trust store.
      ${ns}.trustStoreTypeJKSThe type of the trust store.
      + +## Preparing the key stores + +Key stores can be generated by `keytool` program. The reference documentation for this tool for +Java 8 is [here](https://docs.oracle.com/javase/8/docs/technotes/tools/unix/keytool.html). +The most basic steps to configure the key stores and the trust store for a Spark Standalone +deployment mode is as follows: + +* Generate a key pair for each node +* Export the public key of the key pair to a file on each node +* Import all exported public keys into a single trust store +* Distribute the trust store to the cluster nodes ### YARN mode -The key-store can be prepared on the client side and then distributed and used by the executors as the part of the application. It is possible because the user is able to deploy files before the application is started in YARN by using `spark.yarn.dist.files` or `spark.yarn.dist.archives` configuration settings. The responsibility for encryption of transferring these files is on YARN side and has nothing to do with Spark. -For long-running apps like Spark Streaming apps to be able to write to HDFS, it is possible to pass a principal and keytab to `spark-submit` via the `--principal` and `--keytab` parameters respectively. The keytab passed in will be copied over to the machine running the Application Master via the Hadoop Distributed Cache (securely - if YARN is configured with SSL and HDFS encryption is enabled). The Kerberos login will be periodically renewed using this principal and keytab and the delegation tokens required for HDFS will be generated periodically so the application can continue writing to HDFS. +To provide a local trust store or key store file to drivers running in cluster mode, they can be +distributed with the application using the `--files` command line argument (or the equivalent +`spark.files` configuration). The files will be placed on the driver's working directory, so the TLS +configuration should just reference the file name with no absolute path. + +Distributing local key stores this way may require the files to be staged in HDFS (or other similar +distributed file system used by the cluster), so it's recommended that the undelying file system be +configured with security in mind (e.g. by enabling authentication and wire encryption). ### Standalone mode -The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors. + +The user needs to provide key stores and configuration options for master and workers. They have to +be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in +`SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. + +The user may allow the executors to use the SSL settings inherited from the worker process. That +can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. In that case, the settings +provided by the user on the client side are not used. ### Mesos mode -Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based secrets. 
Spark allows the specification of file-based and environment variable based secrets with the `spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. Depending on the secret store backend secrets can be passed by reference or by value with the `spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, respectively. Reference type secrets are served by the secret store and referred to by name, for example `/mysecret`. Value type secrets are passed on the command line and translated into their appropriate files or environment variables. +Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based +secrets. Spark allows the specification of file-based and environment variable based secrets with +`spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. -### Preparing the key-stores -Key-stores can be generated by `keytool` program. The reference documentation for this tool is -[here](https://docs.oracle.com/javase/7/docs/technotes/tools/solaris/keytool.html). The most basic -steps to configure the key-stores and the trust-store for the standalone deployment mode is as -follows: +Depending on the secret store backend secrets can be passed by reference or by value with the +`spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, +respectively. -* Generate a keys pair for each node -* Export the public key of the key pair to a file on each node -* Import all exported public keys into a single trust-store -* Distribute the trust-store over the nodes +Reference type secrets are served by the secret store and referred to by name, for example +`/mysecret`. Value type secrets are passed on the command line and translated into their +appropriate files or environment variables. + +## HTTP Security Headers -### Configuring SASL Encryption +Apache Spark can be configured to include HTTP headers to aid in preventing Cross Site Scripting +(XSS), Cross-Frame Scripting (XFS), MIME-Sniffing, and also to enforce HTTP Strict Transport +Security. -SASL encryption is currently supported for the block transfer service when authentication -(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set -`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration. + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.ui.xXssProtection1; mode=block + Value for HTTP X-XSS-Protection response header. You can choose appropriate value + from below: +
        +
      • 0 (Disables XSS filtering)
      • +
      • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, + the browser will sanitize the page.)
      • +
      • 1; mode=block (Enables XSS filtering. The browser will prevent rendering + of the page if an attack is detected.)
      • +
      +
      spark.ui.xContentTypeOptions.enabledtrue + When enabled, X-Content-Type-Options HTTP response header will be set to "nosniff". +
      spark.ui.strictTransportSecurityNone + Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate + value from below and set expire-time accordingly. This option is only used when + SSL/TLS is enabled. +
        +
      • max-age=<expire-time>
      • +
      • max-age=<expire-time>; includeSubDomains
      • +
      • max-age=<expire-time>; preload
      • +
      +
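For illustration, these headers are enabled through ordinary Spark configuration, for example on the `spark-submit` command line. The values below are examples only, not recommended settings:

```bash
# Illustrative only: sets the three response headers described in the table above.
# spark.ui.strictTransportSecurity takes effect only when SSL/TLS is enabled.
./bin/spark-submit \
  --conf spark.ui.xXssProtection="1; mode=block" \
  --conf spark.ui.xContentTypeOptions.enabled=true \
  --conf spark.ui.strictTransportSecurity="max-age=31536000; includeSubDomains" \
  --class org.apache.spark.examples.SparkPi \
  examples/jars/spark-examples*.jar 100
```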
      -When using an external shuffle service, it's possible to disable unencrypted connections by setting -`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that -option is enabled, applications that are not set up to use SASL encryption will fail to connect to -the shuffle service. -## Configuring Ports for Network Security +# Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight firewall settings. Below are the primary ports that Spark uses for its communication and how to configure those ports. -### Standalone mode only +## Standalone mode only @@ -145,7 +586,7 @@ configure those ports.
      -### All cluster managers +## All cluster managers @@ -186,54 +627,70 @@ configure those ports.
-### HTTP Security Headers -Apache Spark can be configured to include HTTP Headers which aids in preventing Cross -Site Scripting (XSS), Cross-Frame Scripting (XFS), MIME-Sniffing and also enforces HTTP -Strict Transport Security. +# Kerberos + +Spark supports submitting applications in environments that use Kerberos for authentication. +In most cases, Spark relies on the credentials of the currently logged-in user when authenticating +to Kerberos-aware services. Such credentials can be obtained by logging in to the configured KDC +with tools like `kinit`. + +When talking to Hadoop-based services, Spark needs to obtain delegation tokens so that non-local +processes can authenticate. Spark ships with support for HDFS and other Hadoop file systems, Hive +and HBase. + +When using a Hadoop filesystem (such as HDFS or WebHDFS), Spark will acquire the relevant tokens +for the service hosting the user's home directory. + +An HBase token will be obtained if HBase is in the application's classpath, and the HBase +configuration has Kerberos authentication turned on (`hbase.security.authentication=kerberos`). + +Similarly, a Hive token will be obtained if Hive is in the classpath, and the configuration includes +URIs for remote metastore services (`hive.metastore.uris` is not empty). + +Delegation tokens are currently only supported in YARN and Mesos modes. Consult the +deployment-specific page for more information. + +The following options provide finer-grained control for this feature: +
      Property NameDefaultMeaning
      spark.ui.xXssProtection1; mode=block - Value for HTTP X-XSS-Protection response header. You can choose appropriate value - from below: -
        -
      • 0 (Disables XSS filtering)
      • -
      • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, - the browser will sanitize the page.)
      • -
      • 1; mode=block (Enables XSS filtering. The browser will prevent rendering - of the page if an attack is detected.)
      • -
      -
      spark.ui.xContentTypeOptions.enabledspark.security.credentials.${service}.enabled true - When value is set to "true", X-Content-Type-Options HTTP response header will be set - to "nosniff". Set "false" to disable. -
      spark.ui.strictTransportSecurityNone - Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate - value from below and set expire-time accordingly, when Spark is SSL/TLS enabled. -
        -
      • max-age=<expire-time>
      • -
      • max-age=<expire-time>; includeSubDomains
      • -
      • max-age=<expire-time>; preload
      • -
      + Controls whether to obtain credentials for services when security is enabled. + By default, credentials for all supported services are retrieved when those services are + configured, but it's possible to disable that behavior if it somehow conflicts with the + application being run.
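For illustration, a single provider can be disabled per application while the others remain active. The sketch below assumes a YARN cluster-mode submission and turns off only the Hive provider:

```bash
# Illustrative only: skip obtaining Hive delegation tokens for this application.
./bin/spark-submit \
  --master yarn \
  --deploy-mode cluster \
  --conf spark.security.credentials.hive.enabled=false \
  --class org.apache.spark.examples.SparkPi \
  examples/jars/spark-examples*.jar 100
```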
      - -See the [configuration page](configuration.html) for more details on the security configuration -parameters, and -org.apache.spark.SecurityManager for implementation details about security. +## Long-Running Applications + +Long-running applications may run into issues if their run time exceeds the maximum delegation +token lifetime configured in services it needs to access. + +Spark supports automatically creating new tokens for these applications when running in YARN mode. +Kerberos credentials need to be provided to the Spark application via the `spark-submit` command, +using the `--principal` and `--keytab` parameters. + +The provided keytab will be copied over to the machine running the Application Master via the Hadoop +Distributed Cache. For this reason, it's strongly recommended that both YARN and HDFS be secured +with encryption, at least. + +The Kerberos login will be periodically renewed using the provided credentials, and new delegation +tokens for supported will be created. + + +# Event Logging + +If your applications are using event logging, the directory where the event logs go +(`spark.eventLog.dir`) should be manually created with proper permissions. To secure the log files, +the directory permissions should be set to `drwxrwxrwxt`. The owner and group of the directory +should correspond to the super user who is running the Spark History Server. +This will allow all users to write to the directory but will prevent unprivileged users from +reading, removing or renaming a file unless they own it. The event log files will be created by +Spark with permissions such that only the user and group have read and write access. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 8fa643abf137..f06e72a387df 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -338,7 +338,7 @@ worker during one single schedule iteration. # Monitoring and Logging -Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. +Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default, you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. In addition, detailed log output for each job is also written to the work directory of each slave node (`SPARK_HOME/work` by default). You will see two files for each job, `stdout` and `stderr`, with all output it wrote to its console. diff --git a/docs/sparkr.md b/docs/sparkr.md index 997ea60fb6cf..4faad2c4c182 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -107,7 +107,7 @@ The following Spark driver properties can be set in `sparkConfig` with `sparkR.s With a `SparkSession`, applications can create `SparkDataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). ### From local data frames -The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. 
As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R. +The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically, we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` using the `faithful` dataset from R.
      {% highlight r %} @@ -169,7 +169,7 @@ df <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings {% endhighlight %}
      -The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example we can save the SparkDataFrame from the previous example +The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example, we can save the SparkDataFrame from the previous example to a Parquet file using `write.df`.
      @@ -241,7 +241,7 @@ head(filter(df, df$waiting < 50)) ### Grouping, Aggregation -SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example, we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below
{% highlight r %} @@ -596,7 +596,7 @@ The following example shows how to save/load a MLlib model by SparkR. # Structured Streaming -SparkR supports the Structured Streaming API (experimental). Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) +SparkR supports the Structured Streaming API. Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information, see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) # R Function Name Conflicts @@ -663,3 +663,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - The `stringsAsFactors` parameter was previously ignored with `collect`, for example, in `collect(createDataFrame(iris), stringsAsFactors = TRUE))`. It has been corrected. - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. - A warning can be raised if versions of SparkR package and the Spark JVM do not match. + +## Upgrading to SparkR 2.3.1 and above + + - In SparkR 2.3.0 and earlier, the `start` parameter of the `substr` method was wrongly subtracted by one and treated as 0-based. This can lead to inconsistent substring results and also does not match the behaviour of `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of the `substr` method is now 1-based. As an example, `substr(lit('abcdef'), 2, 4)` would result in `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index dc3e384008d2..075b953a0898 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -165,7 +165,7 @@ In addition to simple column references and expressions, Datasets also have a ri
      -In Python it's possible to access a DataFrame's columns either by attribute +In Python, it's possible to access a DataFrame's columns either by attribute (`df.age`) or by indexing (`df['age']`). While the former is convenient for interactive data exploration, users are highly encouraged to use the latter form, which is future proof and won't break with column names that @@ -278,7 +278,7 @@ the bytes back into an object. Spark SQL supports two different methods for converting existing RDDs into Datasets. The first method uses reflection to infer the schema of an RDD that contains specific types of objects. This -reflection based approach leads to more concise code and works well when you already know the schema +reflection-based approach leads to more concise code and works well when you already know the schema while writing your Spark application. The second method for creating Datasets is through a programmatic interface that allows you to @@ -915,6 +915,14 @@ spark.catalog.refreshTable("my_table")
      +
      + +{% highlight r %} +refreshTable("my_table") +{% endhighlight %} + +
      +
      {% highlight sql %} @@ -953,8 +961,10 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession spark.sql.parquet.compression.codec snappy - Sets the compression codec use when writing Parquet files. Acceptable values include: - uncompressed, snappy, gzip, lzo. + Sets the compression codec used when writing Parquet files. If either `compression` or + `parquet.compression` is specified in the table-specific options/properties, the precedence would be + `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: + none, uncompressed, snappy, gzip, lzo. @@ -994,6 +1004,29 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession +## ORC Files + +Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. +To do that, the following configurations are newly added. The vectorized reader is used for the +native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` +is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC +serde tables (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), +the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also set to `true`. + + + + + + + + + + + + + +
      Property NameDefaultMeaning
spark.sql.orc.implhiveThe name of the ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1.
spark.sql.orc.enableVectorizedReadertrueEnables vectorized ORC decoding in the native implementation. If false, a new non-vectorized ORC reader is used in the native implementation. For the hive implementation, this is ignored.
      + ## JSON Datasets
      @@ -1181,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used 1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 1.2.1. + options are 0.12.0 through 2.3.2. @@ -1210,7 +1243,7 @@ The following options can be used to configure the version of Hive that is used com.mysql.jdbc,
      org.postgresql,
      com.microsoft.sqlserver,
      oracle.jdbc

      - A comma separated list of class prefixes that should be loaded using the classloader that is + A comma-separated list of class prefixes that should be loaded using the classloader that is shared between Spark SQL and a specific version of Hive. An example of classes that should be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need to be shared are those that interact with classes that are already shared. For example, @@ -1243,7 +1276,7 @@ provide a ClassTag. (Note that this is different than the Spark SQL JDBC server, which allows other applications to run queries using Spark SQL). -To get started you will need to include the JDBC driver for you particular database on the +To get started you will need to include the JDBC driver for your particular database on the spark classpath. For example, to connect to postgres from the Spark Shell you would run the following command: @@ -1408,7 +1441,7 @@ SELECT * FROM resultTable # Performance Tuning -For some workloads it is possible to improve performance by either caching data in memory, or by +For some workloads, it is possible to improve performance by either caching data in memory, or by turning on some experimental options. ## Caching Data In Memory @@ -1496,10 +1529,10 @@ that these options will be deprecated in future release as more optimizations ar ## Broadcast Hint for SQL Queries The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view. -When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, +When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`. When both sides of a join are specified, Spark broadcasts the one having the lower statistics. -Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) +Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) support BHJ. When the broadcast nested loop join is selected, we still respect the hint.
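For illustration, the hint can be supplied directly in SQL, here via the `spark-sql` shell; the table and column names are hypothetical:

```bash
# Illustrative only: ask Spark to broadcast small_table when joining it with large_table.
./bin/spark-sql -e "SELECT /*+ BROADCAST(small_table) */ *
                    FROM large_table
                    JOIN small_table ON large_table.key = small_table.key"
```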

      @@ -1630,8 +1663,155 @@ Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` a You may run `./bin/spark-sql --help` for a complete list of all available options. +# PySpark Usage Guide for Pandas with Apache Arrow + +## Apache Arrow in Spark + +Apache Arrow is an in-memory columnar data format that is used in Spark to efficiently transfer +data between JVM and Python processes. This currently is most beneficial to Python users that +work with Pandas/NumPy data. Its usage is not automatic and might require some minor +changes to configuration or code to take full advantage and ensure compatibility. This guide will +give a high-level description of how to use Arrow in Spark and highlight any differences when +working with Arrow-enabled data. + +### Ensure PyArrow Installed + +If you install PySpark using pip, then PyArrow can be brought in as an extra dependency of the +SQL module with the command `pip install pyspark[sql]`. Otherwise, you must ensure that PyArrow +is installed and available on all cluster nodes. The current supported version is 0.8.0. +You can install using pip or conda from the conda-forge channel. See PyArrow +[installation](https://arrow.apache.org/docs/python/install.html) for details. + +## Enabling for Conversion to/from Pandas + +Arrow is available as an optimization when converting a Spark DataFrame to a Pandas DataFrame +using the call `toPandas()` and when creating a Spark DataFrame from a Pandas DataFrame with +`createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set +the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default. + +In addition, optimizations enabled by 'spark.sql.execution.arrow.enabled' could fallback automatically +to non-Arrow optimization implementation if an error occurs before the actual computation within Spark. +This can be controlled by 'spark.sql.execution.arrow.fallback.enabled'. + +
      +
      +{% include_example dataframe_with_arrow python/sql/arrow.py %} +
      +
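The full example referenced above lives in `python/sql/arrow.py`; the essence of the conversion path, assuming `spark` is an existing `SparkSession`, is roughly:

{% highlight python %}
import numpy as np
import pandas as pd

# Enable Arrow-based columnar data transfers (off by default).
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

# Create a Spark DataFrame from a Pandas DataFrame using Arrow.
pdf = pd.DataFrame(np.random.rand(100, 3))
df = spark.createDataFrame(pdf)

# Convert the Spark DataFrame back to a Pandas DataFrame using Arrow.
result_pdf = df.select("*").toPandas()
{% endhighlight %}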
+ +Using the above optimizations with Arrow will produce the same results as when Arrow is not +enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the +DataFrame to the driver program and should be done on a small subset of the data. Not all Spark +data types are currently supported and an error can be raised if a column has an unsupported type, +see [Supported SQL Types](#supported-sql-types). If an error occurs during `createDataFrame()`, +Spark will fall back to creating the DataFrame without Arrow. + +## Pandas UDFs (a.k.a. Vectorized UDFs) + +Pandas UDFs are user-defined functions that are executed by Spark using Arrow to transfer data and +Pandas to work with the data. A Pandas UDF is defined using `pandas_udf`, either as a decorator +or to wrap the function; no additional configuration is required. Currently, there are two types of +Pandas UDF: Scalar and Grouped Map. + +### Scalar + +Scalar Pandas UDFs are used for vectorizing scalar operations. They can be used with functions such +as `select` and `withColumn`. The Python function should take `pandas.Series` as inputs and return +a `pandas.Series` of the same length. Internally, Spark will execute a Pandas UDF by splitting +columns into batches and calling the function for each batch as a subset of the data, then +concatenating the results together. + +The following example shows how to create a scalar Pandas UDF that computes the product of 2 columns. +
      +
      +{% include_example scalar_pandas_udf python/sql/arrow.py %} +
      +
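For readers without the included file at hand, a sketch of such a scalar Pandas UDF follows; the column name `x` and the sample data are illustrative only.

{% highlight python %}
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

# The function operates on pandas.Series and returns a Series of the same length.
def multiply_func(a, b):
    return a * b

multiply = pandas_udf(multiply_func, returnType=LongType())

df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["x"]))

# Spark splits the column into batches and calls the UDF on each batch.
df.select(multiply(col("x"), col("x"))).show()
{% endhighlight %}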
+ +### Grouped Map +Grouped map Pandas UDFs are used with `groupBy().apply()`, which implements the "split-apply-combine" pattern. +Split-apply-combine consists of three steps: +* Split the data into groups by using `DataFrame.groupBy`. +* Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The + input data contains all the rows and columns for each group. +* Combine the results into a new `DataFrame`. + +To use `groupBy().apply()`, the user needs to define the following: +* A Python function that defines the computation for each group. +* A `StructType` object or a string that defines the schema of the output `DataFrame`. + +Note that all data for a group will be loaded into memory before the function is applied. This can +lead to out-of-memory exceptions, especially if the group sizes are skewed. The configuration for +[maxRecordsPerBatch](#setting-arrow-batch-size) is not applied to groups, and it is up to the user +to ensure that the grouped data will fit into the available memory. + +The following example shows how to use `groupby().apply()` to subtract the mean from each value in the group. +
      +
      +{% include_example grouped_map_pandas_udf python/sql/arrow.py %} +
      +
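Again as a sketch of the pattern (the `id` and `v` columns are illustrative; the bundled `python/sql/arrow.py` example is the authoritative version):

{% highlight python %}
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

# The function receives and returns a pandas.DataFrame for each group;
# the string argument declares the schema of the returned DataFrame.
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby("id").apply(subtract_mean).show()
{% endhighlight %}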
      + +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and +[`pyspark.sql.GroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.GroupedData.apply). + +## Usage Notes + +### Supported SQL Types + +Currently, all Spark SQL data types are supported by Arrow-based conversion except `BinaryType`, `MapType`, +`ArrayType` of `TimestampType`, and nested `StructType`. + +### Setting Arrow Batch Size + +Data partitions in Spark are converted into Arrow record batches, which can temporarily lead to +high memory usage in the JVM. To avoid possible out of memory exceptions, the size of the Arrow +record batches can be adjusted by setting the conf "spark.sql.execution.arrow.maxRecordsPerBatch" +to an integer that will determine the maximum number of rows for each batch. The default value is +10,000 records per batch. If the number of columns is large, the value should be adjusted +accordingly. Using this limit, each data partition will be made into 1 or more record batches for +processing. + +### Timestamp with Time Zone Semantics + +Spark internally stores timestamps as UTC values, and timestamp data that is brought in without +a specified time zone is converted as local time to UTC with microsecond resolution. When timestamp +data is exported or displayed in Spark, the session time zone is used to localize the timestamp +values. The session time zone is set with the configuration 'spark.sql.session.timeZone' and will +default to the JVM system local time zone if not set. Pandas uses a `datetime64` type with nanosecond +resolution, `datetime64[ns]`, with optional time zone on a per-column basis. + +When timestamp data is transferred from Spark to Pandas it will be converted to nanoseconds +and each column will be converted to the Spark session time zone then localized to that time +zone, which removes the time zone and displays values as local time. This will occur +when calling `toPandas()` or `pandas_udf` with timestamp columns. + +When timestamp data is transferred from Pandas to Spark, it will be converted to UTC microseconds. This +occurs when calling `createDataFrame` with a Pandas DataFrame or when returning a timestamp from a +`pandas_udf`. These conversions are done automatically to ensure Spark will have data in the +expected format, so it is not necessary to do any of these conversions yourself. Any nanosecond +values will be truncated. + +Note that a standard UDF (non-Pandas) will load timestamp data as Python datetime objects, which is +different than a Pandas timestamp. It is recommended to use Pandas time series functionality when +working with timestamps in `pandas_udf`s to get the best performance, see +[here](https://pandas.pydata.org/pandas-docs/stable/timeseries.html) for details. + # Migration Guide +## Upgrading From Spark SQL 2.3 to 2.4 + + - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. + - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. 
+ - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if the dataframe physically has no partitions. This introduces a small behavior change: for self-describing file formats like Parquet and ORC, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent with respect to writing an empty dataframe. + - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, a column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. + - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file format (parquet, orc, json, text, csv, etc.) is not allowed. An exception is thrown when attempting to write dataframes with an empty schema. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promoting both sides to TIMESTAMP. Setting `spark.sql.hive.compareDateTimestampInTimestamp` to `false` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with a nonempty location is not allowed. An exception is thrown when attempting to create a managed table with a nonempty location. Setting `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` to `true` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, regardless of the order of the input arguments. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these two functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains a timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. If you do not care about this problem and want to retain the previous behavior to keep your queries unchanged, you can set `spark.sql.function.rejectTimezoneInString` to `false`. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. @@ -1778,14 +1958,22 @@ options.
Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - - - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). + - In PySpark, `na.fill()` or `fillna` also accepts booleans and replaces nulls with booleans. In prior Spark versions, PySpark just ignored them and returned the original Dataset/DataFrame. + - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). + - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns a string. Until Spark 2.3, it always returned a string regardless of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns a string. Until Spark 2.3, it always returned a string regardless of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. - - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant with the SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes: + - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, i.e. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive modulus (`pmod`). + - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. + - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced.
It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. + - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone. ## Upgrading From Spark SQL 2.1 to 2.2 - - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. + - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. + + - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). ## Upgrading From Spark SQL 2.0 to 2.1 @@ -1806,7 +1994,7 @@ options. transformations (e.g., `map`, `filter`, and `groupByKey`) and untyped transformations (e.g., `select` and `groupBy`) are available on the Dataset class. 
Since compile-time type-safety in Python and R is not a language feature, the concept of Dataset does not apply to these languages’ - APIs. Instead, `DataFrame` remains the primary programing abstraction, which is analogous to the + APIs. Instead, `DataFrame` remains the primary programming abstraction, which is analogous to the single-node data frame notion in these languages. - Dataset and DataFrame API `unionAll` has been deprecated and replaced by `union` @@ -1826,7 +2014,7 @@ options. ## Upgrading From Spark SQL 1.5 to 1.6 - - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC + - From Spark 1.6, by default, the Thrift server runs in multi-session mode. Which means each JDBC/ODBC connection owns a copy of their own SQL configuration and temporary function registry. Cached tables are still shared though. If you prefer to run the Thrift server in the old single-session mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add @@ -1974,7 +2162,7 @@ been renamed to `DataFrame`. This is primarily because DataFrames no longer inhe directly, but instead provide most of the functionality that RDDs provide though their own implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. -In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for +In Scala, there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for some use cases. It is still recommended that users update their code to use `DataFrame` instead. Java and Python users will need to update their code. @@ -1982,12 +2170,12 @@ Java and Python users will need to update their code. Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users -of either language should use `SQLContext` and `DataFrame`. In general theses classes try to -use types that are usable from both languages (i.e. `Array` instead of language specific collections). +of either language should use `SQLContext` and `DataFrame`. In general these classes try to +use types that are usable from both languages (i.e. `Array` instead of language-specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. -Additionally the Java specific types API has been removed. Users of both Scala and Java should +Additionally, the Java specific types API has been removed. Users of both Scala and Java should use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. @@ -2044,9 +2232,9 @@ referencing a singleton. ## Compatibility with Apache Hive Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. -Currently Hive SerDes and UDFs are based on Hive 1.2.1, +Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses @@ -2136,10 +2324,10 @@ A handful of Hive optimizations are not yet included in Spark. 
Some of these (su less important due to Spark SQL's in-memory computational model. Others are slotted for future releases of Spark SQL. -* Block level bitmap indexes and virtual columns (used to build indexes) -* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you +* Block-level bitmap indexes and virtual columns (used to build indexes) +* Automatically determine the number of reducers for joins and groupbys: Currently, in Spark SQL, you need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". -* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still +* Meta-data only query: For queries that can be answered by using only metadata, Spark SQL still launches tasks to compute the result. * Skew data flag: Spark SQL does not follow the skew data flags in Hive. * `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint. @@ -2163,7 +2351,7 @@ Not all the APIs of the Hive UDF/UDTF/UDAF are supported by Spark SQL. Below are Spark SQL currently does not support the reuse of aggregation. * `getWindowingEvaluator` (`GenericUDAFEvaluator`) is a function to optimize aggregation by evaluating an aggregate over a fixed window. - + ### Incompatible Hive UDF Below are the scenarios in which Hive and Spark generate different results: @@ -2796,6 +2984,6 @@ does not exactly match standard floating point semantics. Specifically: - NaN = NaN returns true. - - In aggregations all NaN values are grouped together. + - In aggregations, all NaN values are grouped together. - NaN is treated as a normal value in join keys. - NaN values go last when in ascending order, larger than any other numeric value. diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index f4bb2353e3c4..dacaa3438d48 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -39,10 +39,10 @@ For example, for Maven support, add the following to the pom.xml fi # Configuration Parameters Create core-site.xml and place it inside Spark's conf directory. -The main category of parameters that should be configured are the authentication parameters +The main category of parameters that should be configured is the authentication parameters required by Keystone. -The following table contains a list of Keystone mandatory parameters. PROVIDER can be +The following table contains a list of Keystone mandatory parameters. PROVIDER can be any (alphanumeric) name. diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 257a4f7d4f3c..a1b6942ffe0a 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -17,7 +17,7 @@ Choose a machine in your cluster such that - Flume can be configured to push data to a port on that machine. -Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able push data. +Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able to push data. #### Configuring Flume Configure Flume agent to send data to an Avro sink by having the following in the configuration file. @@ -100,7 +100,7 @@ Choose a machine that will run the custom sink in a Flume agent. The rest of the #### Configuring Flume Configuring Flume on the chosen machine requires the following two steps. -1. 
**Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink . +1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink. (i) *Custom sink JAR*: Download the JAR corresponding to the following artifact (or [direct link](http://search.maven.org/remotecontent?filepath=org/apache/spark/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}/{{site.SPARK_VERSION_SHORT}}/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}-{{site.SPARK_VERSION_SHORT}}.jar)). @@ -128,7 +128,7 @@ Configuring Flume on the chosen machine requires the following two steps. agent.sinks.spark.port = agent.sinks.spark.channel = memoryChannel - Also make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. + Also, make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about configuring Flume agents. diff --git a/docs/streaming-kafka-0-8-integration.md b/docs/streaming-kafka-0-8-integration.md index 9f0671da2ee3..becf217738d2 100644 --- a/docs/streaming-kafka-0-8-integration.md +++ b/docs/streaming-kafka-0-8-integration.md @@ -10,7 +10,7 @@ Here we explain how to configure Spark Streaming to receive data from Kafka. The ## Approach 1: Receiver-based Approach This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. -However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write-Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write-ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write-Ahead Logs. Next, we discuss how to use this approach in your streaming application. @@ -55,11 +55,11 @@ Next, we discuss how to use this approach in your streaming application. **Points to remember:** - - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. 
So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. + - Topic partitions in Kafka do not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in `KafkaUtils.createStream()` only increases the number of threads used to consume topics within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. - - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use + - If you have enabled Write-Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, set the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use `KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). 3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. @@ -80,9 +80,9 @@ This approach has the following advantages over the receiver-based approach (i.e - *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. -- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write-Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write-Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write-Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures.
This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). +- *Exactly-once semantics:* The first approach uses Kafka's high-level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write-ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use the simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for the output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 868acc41226d..c30959263cdf 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -74,7 +74,7 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Create a local StreamingContext with two working thread and batch interval of 1 second. -// The master requires 2 cores to prevent from a starvation scenario. +// The master requires 2 cores to prevent a starvation scenario. val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") val ssc = new StreamingContext(conf, Seconds(1)) @@ -172,7 +172,7 @@ each line will be split into multiple words and the stream of words is represent `words` DStream. Note that we defined the transformation using a [FlatMapFunction](api/scala/index.html#org.apache.spark.api.java.function.FlatMapFunction) object. As we will discover along the way, there are a number of such convenience classes in the Java API -that help define DStream transformations. +that help to define DStream transformations. Next, we want to count these words.
@@ -1461,7 +1461,7 @@ Note that the connections in the pool should be lazily created on demand and tim *** ## DataFrame and SQL Operations -You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. +You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore, this has to be done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL.
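A rough Python sketch of that singleton pattern is shown below; `words` is assumed to be the DStream from the earlier word count example, and the helper name is illustrative.

{% highlight python %}
from pyspark.sql import Row, SparkSession

def getSparkSessionInstance(sparkConf):
    # Lazily create (or re-create after driver recovery) a single SparkSession.
    if "sparkSessionSingletonInstance" not in globals():
        globals()["sparkSessionSingletonInstance"] = SparkSession \
            .builder \
            .config(conf=sparkConf) \
            .getOrCreate()
    return globals()["sparkSessionSingletonInstance"]

def process(time, rdd):
    spark = getSparkSessionInstance(rdd.context.getConf())
    # Convert the RDD of words to a DataFrame, register a temp view, query with SQL.
    wordsDataFrame = spark.createDataFrame(rdd.map(lambda w: Row(word=w)))
    wordsDataFrame.createOrReplaceTempView("words")
    spark.sql("select word, count(*) as total from words group by word").show()

words.foreachRDD(process)
{% endhighlight %}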
      @@ -2010,10 +2010,10 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. -- *Configuring write ahead logs* - Since Spark 1.2, - we have introduced _write ahead logs_ for achieving strong +- *Configuring write-ahead logs* - Since Spark 1.2, + we have introduced _write-ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into - a write ahead log in the configuration checkpoint directory. This prevents data loss on driver + a write-ahead log in the configuration checkpoint directory. This prevents data loss on driver recovery, thus ensuring zero data loss (discussed in detail in the [Fault-tolerance Semantics](#fault-tolerance-semantics) section). This can be enabled by setting the [configuration parameter](configuration.html#spark-streaming) @@ -2021,15 +2021,15 @@ To run a Spark Streaming applications, you need to have the following. come at the cost of the receiving throughput of individual receivers. This can be corrected by running [more receivers in parallel](#level-of-parallelism-in-data-receiving) to increase aggregate throughput. Additionally, it is recommended that the replication of the - received data within Spark be disabled when the write ahead log is enabled as the log is already + received data within Spark be disabled when the write-ahead log is enabled as the log is already stored in a replicated storage system. This can be done by setting the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER`. While using S3 (or any file system that - does not support flushing) for _write ahead logs_, please remember to enable + does not support flushing) for _write-ahead logs_, please remember to enable `spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and `spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See [Spark Streaming Configuration](configuration.html#spark-streaming) for more details. - Note that Spark will not encrypt data written to the write ahead log when I/O encryption is - enabled. If encryption of the write ahead log data is desired, it should be stored in a file + Note that Spark will not encrypt data written to the write-ahead log when I/O encryption is + enabled. If encryption of the write-ahead log data is desired, it should be stored in a file system that supports encryption natively. - *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming @@ -2284,9 +2284,9 @@ Having bigger blockinterval means bigger blocks. A high value of `spark.locality - Instead of relying on batchInterval and blockInterval, you can define the number of partitions by calling `inputDstream.repartition(n)`. This reshuffles the data in RDD randomly to create n number of partitions. Yes, for greater parallelism. Though comes at the cost of a shuffle. An RDD's processing is scheduled by driver's jobscheduler as a job. At a given point of time only one job is active. So, if one job is executing the other jobs are queued. -- If you have two dstreams there will be two RDDs formed and there will be two jobs created which will be scheduled one after the another. To avoid this, you can union two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However the partitioning of the RDDs is not impacted. 
+- If you have two dstreams, there will be two RDDs formed and there will be two jobs created, which will be scheduled one after another. To avoid this, you can union the two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However, the partitioning of the RDDs is not impacted. -- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up in throwing exceptions (most probably BlockNotFoundException). Currently there is no way to pause the receiver. Using SparkConf configuration `spark.streaming.receiver.maxRate`, rate of receiver can be limited. +- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up throwing exceptions (most probably BlockNotFoundException). Currently, there is no way to pause the receiver. Using the SparkConf configuration `spark.streaming.receiver.maxRate`, the rate of the receiver can be limited. *************************************************************************************************** @@ -2388,7 +2388,7 @@ then besides these losses, all of the past data that was received and replicated lost. This will affect the results of the stateful transformations. To avoid this loss of past received data, Spark 1.2 introduced _write -ahead logs_ which save the received data to fault-tolerant storage. With the [write ahead logs +ahead logs_ which save the received data to fault-tolerant storage. With the [write-ahead logs enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides an at-least once guarantee. The following table summarizes the semantics under failures: @@ -2402,7 +2402,7 @@ The following table summarizes the semantics under failures:
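As a sketch of how the write-ahead log, receiver rate limit, and storage level discussed above might be wired together in PySpark (the write-ahead log key is the one documented on the configuration page; the host, port, and rate values are placeholders):

{% highlight python %}
from pyspark import SparkConf, SparkContext, StorageLevel
from pyspark.streaming import StreamingContext

conf = (SparkConf()
        .setAppName("NetworkWordCount")
        # Enable the receiver write-ahead log for zero data loss on driver recovery.
        .set("spark.streaming.receiver.writeAheadLog.enable", "true")
        # Cap the per-receiver ingestion rate (records per second).
        .set("spark.streaming.receiver.maxRate", "10000"))

sc = SparkContext(conf=conf)
ssc = StreamingContext(sc, 1)

# With the log enabled, skip in-Spark replication of the received data.
lines = ssc.socketTextStream("localhost", 9999, StorageLevel.MEMORY_AND_DISK_SER)
{% endhighlight %}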
- + + + + + + @@ -1500,7 +1946,7 @@ aggDF \ .format("console") \ .start() -# Have all the aggregates in an in memory table. The query name will be the table name +# Have all the aggregates in an in-memory table. The query name will be the table name aggDF \ .writeStream \ .queryName("aggregates") \ @@ -1543,7 +1989,7 @@ head(sql("select * from aggregates")) -#### Using Foreach +##### Using Foreach The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.1, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points. @@ -1560,6 +2006,172 @@ which has methods that get called whenever there is a sequence of rows generated - Whenever `open` is called, `close` will also be called (unless the JVM exits due to some error). This is true even if `open` returns false. If there is any error in processing and writing the data, `close` will be called with the error. It is your responsibility to clean up state (e.g. connections, transactions, etc.) that have been created in `open` such that there are no resource leaks. +#### Triggers +The trigger settings of a streaming query define the timing of streaming data processing, whether +the query is going to be executed as a micro-batch query with a fixed batch interval or as a continuous processing query. +Here are the different kinds of triggers that are supported. + +
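As a sketch from PySpark, assuming `df` is a streaming DataFrame and using the console sink purely for illustration, the trigger kinds introduced above are specified like this:

{% highlight python %}
# Default trigger: run each micro-batch as soon as the previous one finishes.
df.writeStream \
    .format("console") \
    .start()

# Fixed-interval micro-batches.
df.writeStream \
    .format("console") \
    .trigger(processingTime="2 seconds") \
    .start()

# One-time micro-batch: process the available data and stop.
df.writeStream \
    .format("console") \
    .trigger(once=True) \
    .start()

# Continuous processing with a one-second checkpoint interval (new in Spark 2.3).
df.writeStream \
    .format("console") \
    .trigger(continuous="1 second") \
    .start()
{% endhighlight %}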
      Spark 1.1 or earlier, OR
      - Spark 1.2 or later without write ahead logs + Spark 1.2 or later without write-ahead logs
      Buffered data lost with unreliable receivers
      @@ -2416,7 +2416,7 @@ The following table summarizes the semantics under failures:
      Spark 1.2 or later with write ahead logsSpark 1.2 or later with write-ahead logs Zero data loss with reliable receivers
      At-least once semantics diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index bab0be8ddeb9..71fd5b10cc40 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -15,7 +15,7 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli For Python applications, you need to add this above library and its dependencies when deploying your application. See the [Deploying](#deploying) subsection below. -For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also see the [Deploying](#deploying) subsection below. +For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also, see the [Deploying](#deploying) subsection below. ## Reading Data from Kafka @@ -61,7 +61,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -70,7 +70,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to multiple topics -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -79,7 +79,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to a pattern -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -125,7 +125,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") ### Creating a Kafka Source for Batch Queries If you have a use case that is better suited to batch processing, -you can create an Dataset/DataFrame for a defined range of offsets. +you can create a Dataset/DataFrame for a defined range of offsets.
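Mirroring the Scala and Java snippets in this hunk, a Python sketch of such a batch read (broker addresses and the topic name are placeholders) would be:

{% highlight python %}
# Subscribe to one topic, reading a bounded range of offsets as a batch query.
df = spark \
    .read \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
    .option("subscribe", "topic1") \
    .option("startingOffsets", "earliest") \
    .option("endingOffsets", "latest") \
    .load()

df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
{% endhighlight %}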
      @@ -171,7 +171,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic defaults to the earliest and latest offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -180,7 +180,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to multiple topics, specifying explicit Kafka offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -191,7 +191,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to a pattern, at the earliest and latest offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -597,7 +597,7 @@ Note that the following Kafka params cannot be set and the Kafka source or sink - **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use DataFrame operations to explicitly serialize the keys into either strings or byte arrays. - **value.serializer**: values are always serialized with ByteArraySerializer or StringSerializer. Use -DataFrame oeprations to explicitly serialize the values into either strings or byte arrays. +DataFrame operations to explicitly serialize the values into either strings or byte arrays. - **enable.auto.commit**: Kafka source doesn't commit any offset. - **interceptor.classes**: Kafka source always read keys and values as byte arrays. It's not safe to use ConsumerInterceptor as it may break the query. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 31fcfabb9cac..602a4c70848e 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -8,9 +8,11 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. 
The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write-Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. +Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able to choose the mode based on your application requirements. + +In this guide, we are going to walk you through the programming model and the APIs. We are going to explain the concepts mostly using the default micro-batch processing model, and then [later](#continuous-processing-experimental) discuss Continuous Processing model. First, let's start with a simple example of a Structured Streaming query - a streaming word count. # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in @@ -477,7 +479,7 @@ detail in the [Window Operations](#window-operations-on-event-time) section. ## Fault Tolerance Semantics Delivering end-to-end exactly-once semantics was one of key goals behind the design of Structured Streaming. To achieve that, we have designed the Structured Streaming sources, the sinks and the execution engine to reliably track the exact progress of the processing so that it can handle any kind of failure by restarting and/or reprocessing. Every streaming source is assumed to have offsets (similar to Kafka offsets, or Kinesis sequence numbers) -to track the read position in the stream. The engine uses checkpointing and write ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. +to track the read position in the stream. The engine uses checkpointing and write-ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. # API using Datasets and DataFrames Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. 
Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` @@ -688,7 +690,7 @@ These examples generate streaming DataFrames that are untyped, meaning that the By default, Structured Streaming from file based sources requires you to specify the schema, rather than rely on Spark to infer it automatically. This restriction ensures a consistent schema will be used for the streaming query, even in the case of failures. For ad-hoc use cases, you can reenable schema inference by setting `spark.sql.streaming.schemaInference` to `true`. -Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). +Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user-provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). ## Operations on streaming DataFrames/Datasets You can apply all kinds of operations on streaming DataFrames/Datasets – ranging from untyped, SQL-like operations (e.g. `select`, `where`, `groupBy`), to typed RDD-like operations (e.g. `map`, `filter`, `flatMap`). See the [SQL programming guide](sql-programming-guide.html) for more details. Let’s take a look at a few example operations that you can use. @@ -827,8 +829,8 @@ df.isStreaming() {% endhighlight %}
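For the schema requirement described above, a Python sketch of a streaming file source with an explicit schema (the path and columns are illustrative) looks roughly like:

{% highlight python %}
from pyspark.sql.types import StructType

# File sources need an explicit schema unless
# spark.sql.streaming.schemaInference is set to true.
userSchema = StructType().add("name", "string").add("age", "integer")

csvDF = spark \
    .readStream \
    .option("sep", ";") \
    .schema(userSchema) \
    .csv("/path/to/directory")
{% endhighlight %}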
      -{% highlight bash %} -Not available. +{% highlight r %} +isStreaming(df) {% endhighlight %}
      @@ -885,11 +887,24 @@ windowedCounts = words.groupBy( ).count() {% endhighlight %} + +
      +{% highlight r %} +words <- ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } + +# Group the data by window and word and compute the count of each group +windowedCounts <- count( + groupBy( + words, + window(words$timestamp, "10 minutes", "5 minutes"), + words$word)) +{% endhighlight %} +
-### Handling Late Data and Watermarking +#### Handling Late Data and Watermarking Now consider what happens if one of the events arrives late to the application. For example, say, a word generated at 12:04 (i.e. event time) could be received by the application at 12:11. The application should use the time 12:04 instead of 12:11 @@ -910,7 +925,9 @@ specifying the event time column and the threshold on how late the data is expec event time. For a specific window starting at time `T`, the engine will maintain state and allow late data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, -but data later than the threshold will be dropped. Let's understand this with an example. We can +but data later than the threshold will start getting dropped +(see [later](#semantic-guarantees-of-aggregation-with-watermarking) +in this section for the exact guarantees). Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below.
      @@ -959,6 +976,21 @@ windowedCounts = words \ .count() {% endhighlight %} +
      +
      +{% highlight r %} +words <- ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } + +# Group the data by window and word and compute the count of each group + +words <- withWatermark(words, "timestamp", "10 minutes") +windowedCounts <- count( + groupBy( + words, + window(words$timestamp, "10 minutes", "5 minutes"), + words$word)) +{% endhighlight %} +
@@ -1001,7 +1033,9 @@ then drops intermediate state of a window < watermark, and appends the final counts to the Result Table/sink. For example, the final counts of window `12:00 - 12:10` is appended to the Result Table only after the watermark is updated to `12:11`. -**Conditions for watermarking to clean aggregation state** +##### Conditions for watermarking to clean aggregation state +{:.no_toc} + It is important to note that the following conditions must be satisfied for the watermarking to clean the state in aggregation queries *(as of Spark 2.1.1, subject to change in the future)*. @@ -1021,9 +1055,31 @@ from the aggregation column. For example, `df.groupBy("time").count().withWatermark("time", "1 min")` is invalid in Append output mode. +##### Semantic Guarantees of Aggregation with Watermarking +{:.no_toc} + +- A watermark delay (set with `withWatermark`) of "2 hours" guarantees that the engine will never +drop any data that is less than 2 hours delayed. In other words, any data less than 2 hours behind +(in terms of event-time) the latest data processed till then is guaranteed to be aggregated. + +- However, the guarantee is strict only in one direction. Data delayed by more than 2 hours is +not guaranteed to be dropped; it may or may not get aggregated. The more delayed the data is, the less +likely the engine is to process it. ### Join Operations -Streaming DataFrames can be joined with static DataFrames to create new streaming DataFrames. Here are a few examples. +Structured Streaming supports joining a streaming Dataset/DataFrame with a static Dataset/DataFrame +as well as another streaming Dataset/DataFrame. The result of the streaming join is generated +incrementally, similar to the results of streaming aggregations in the previous section. In this +section, we will explore what types of joins (i.e. inner, outer, etc.) are supported in the above +cases. Note that in all the supported join types, the result of the join with a streaming +Dataset/DataFrame will be exactly the same as if it were with a static Dataset/DataFrame +containing the same data in the stream. + + +#### Stream-static Joins + +Since its introduction in Spark 2.0, Structured Streaming has supported joins (inner join and some +types of outer joins) between a streaming and a static DataFrame/Dataset. Here is a simple example.
@@ -1058,9 +1114,382 @@ streamingDf.join(staticDf, "type") # inner equi-join with a static DF streamingDf.join(staticDf, "type", "right_outer") # right outer join with a static DF {% endhighlight %} +
      + +
      + +{% highlight r %} +staticDf <- read.df(...) +streamingDf <- read.stream(...) +joined <- merge(streamingDf, staticDf, sort = FALSE) # inner equi-join with a static DF +joined <- join( + staticDf, + streamingDf, + streamingDf$value == staticDf$value, + "right_outer") # right outer join with a static DF +{% endhighlight %} + +
      +
+
+Note that stream-static joins are not stateful, so no state management is necessary.
+However, a few types of stream-static outer joins are not yet supported.
+These are listed at the [end of this Join section](#support-matrix-for-joins-in-streaming-queries).
+
+#### Stream-stream Joins
+In Spark 2.3, we have added support for stream-stream joins, that is, you can join two streaming
+Datasets/DataFrames. The challenge of generating join results between two data streams is that,
+at any point of time, the view of the dataset is incomplete for both sides of the join, making
+it much harder to find matches between inputs. Any row received from one input stream can match
+with any future, yet-to-be-received row from the other input stream. Hence, for both the input
+streams, we buffer past input as streaming state, so that we can match every future input with
+past input and accordingly generate joined results. Furthermore, similar to streaming aggregations,
+we automatically handle late, out-of-order data and can limit the state using watermarks.
+Let’s discuss the different types of supported stream-stream joins and how to use them.
+
+##### Inner Joins with optional Watermarking
+Inner joins on any kind of columns along with any kind of join conditions are supported.
+However, as the stream runs, the size of streaming state will keep growing indefinitely, since
+*all* past input must be saved as any new input can match with any input from the past.
+To avoid unbounded state, you have to define additional join conditions such that indefinitely
+old inputs cannot match with future inputs and therefore can be cleared from the state.
+In other words, you will have to do the following additional steps in the join.
+
+1. Define watermark delays on both inputs such that the engine knows how delayed the input can be
+(similar to streaming aggregations)
+
+1. Define a constraint on event-time across the two inputs such that the engine can figure out when
+old rows of one input are not going to be required (i.e. will not satisfy the time constraint) for
+matches with the other input. This constraint can be defined in one of two ways.
+
+    1. Time range join conditions (e.g. `...JOIN ON leftTime BETWEEN rightTime AND rightTime + INTERVAL 1 HOUR`),
+
+    1. Join on event-time windows (e.g. `...JOIN ON leftTimeWindow = rightTimeWindow`); a sketch of this form appears after the code examples below.
+
+Let’s understand this with an example.
+
+Let’s say we want to join a stream of advertisement impressions (when an ad was shown) with
+another stream of user clicks on advertisements to correlate when impressions led to
+monetizable clicks. To allow the state cleanup in this stream-stream join, you will have to
+specify the watermarking delays and the time constraints as follows.
+
+1. Watermark delays: Say, the impressions and the corresponding clicks can be late/out-of-order
+in event-time by at most 2 and 3 hours, respectively.
+
+1. Event-time range condition: Say, a click can occur within a time range of 0 seconds to 1 hour
+after the corresponding impression.
+
+The code would look like this.
      +
      + +{% highlight scala %} +import org.apache.spark.sql.functions.expr + +val impressions = spark.readStream. ... +val clicks = spark.readStream. ... + +// Apply watermarks on event-time columns +val impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours") +val clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours") + +// Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """) +) + +{% endhighlight %} + +
      +
+
+{% highlight java %}
+import static org.apache.spark.sql.functions.expr;
+
+Dataset<Row> impressions = spark.readStream(). ...
+Dataset<Row> clicks = spark.readStream(). ...
+
+// Apply watermarks on event-time columns
+Dataset<Row> impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours");
+Dataset<Row> clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours");
+
+// Join with event-time constraints
+impressionsWithWatermark.join(
+  clicksWithWatermark,
+  expr(
+    "clickAdId = impressionAdId AND " +
+    "clickTime >= impressionTime AND " +
+    "clickTime <= impressionTime + interval 1 hour ")
+);
+
+{% endhighlight %}
+
      +
      + +{% highlight python %} +from pyspark.sql.functions import expr + +impressions = spark.readStream. ... +clicks = spark.readStream. ... + +# Apply watermarks on event-time columns +impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours") +clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours") + +# Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """) +) + +{% endhighlight %} + +
      +
      + +{% highlight r %} +impressions <- read.stream(...) +clicks <- read.stream(...) + +# Apply watermarks on event-time columns +impressionsWithWatermark <- withWatermark(impressions, "impressionTime", "2 hours") +clicksWithWatermark <- withWatermark(clicks, "clickTime", "3 hours") + +# Join with event-time constraints +joined <- join( + impressionsWithWatermark, + clicksWithWatermark, + expr( + paste( + "clickAdId = impressionAdId AND", + "clickTime >= impressionTime AND", + "clickTime <= impressionTime + interval 1 hour" +))) + +{% endhighlight %} + +
      +
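For the second form of state-cleanup constraint mentioned above, joining on event-time windows, a minimal Scala sketch might look like the following. It reuses the `impressionsWithWatermark` and `clicksWithWatermark` streams from the example above; the one-hour window duration and the derived column names (`impressionWindow`, `clickWindow`) are illustrative assumptions, not part of this guide.

{% highlight scala %}
import org.apache.spark.sql.functions.{col, expr, window}

// Bucket both streams into 1-hour event-time windows (duration chosen for illustration)
val impressionsByWindow = impressionsWithWatermark
  .withColumn("impressionWindow", window(col("impressionTime"), "1 hour"))

val clicksByWindow = clicksWithWatermark
  .withColumn("clickWindow", window(col("clickTime"), "1 hour"))

// Joining on equal event-time windows lets the engine clear state for windows
// that fall behind the watermarks on both sides
impressionsByWindow.join(
  clicksByWindow,
  expr("clickAdId = impressionAdId AND clickWindow = impressionWindow"))
{% endhighlight %}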
+
+###### Semantic Guarantees of Stream-stream Inner Joins with Watermarking
+{:.no_toc}
+This is similar to the [guarantees provided by watermarking on aggregations](#semantic-guarantees-of-aggregation-with-watermarking).
+A watermark delay of "2 hours" guarantees that the engine will never drop any data that is less than
+ 2 hours delayed. But data delayed by more than 2 hours may or may not get processed.
+
+##### Outer Joins with Watermarking
+While the watermark + event-time constraints are optional for inner joins, for left and right outer
+joins they must be specified. This is because for generating the NULL results in outer joins, the
+engine must know when an input row is not going to match with anything in the future. Hence, the
+watermark + event-time constraints must be specified for generating correct results. Therefore,
+a query with an outer join will look much like the ad-monetization example earlier, except that
+there will be an additional parameter specifying it to be an outer join.
      +
      + +{% highlight scala %} + +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """), + joinType = "leftOuter" // can be "inner", "leftOuter", "rightOuter" + ) + +{% endhighlight %} + +
      +
      + +{% highlight java %} +impressionsWithWatermark.join( + clicksWithWatermark, + expr( + "clickAdId = impressionAdId AND " + + "clickTime >= impressionTime AND " + + "clickTime <= impressionTime + interval 1 hour "), + "leftOuter" // can be "inner", "leftOuter", "rightOuter" +); + +{% endhighlight %} + + +
      +
      + +{% highlight python %} +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """), + "leftOuter" # can be "inner", "leftOuter", "rightOuter" +) + +{% endhighlight %} + +
      +
+
+{% highlight r %}
+joined <- join(
+  impressionsWithWatermark,
+  clicksWithWatermark,
+  expr(
+    paste(
+      "clickAdId = impressionAdId AND",
+      "clickTime >= impressionTime AND",
+      "clickTime <= impressionTime + interval 1 hour")),
+  "left_outer"  # can be "inner", "left_outer", "right_outer"
+)
+
+{% endhighlight %}
+
+###### Semantic Guarantees of Stream-stream Outer Joins with Watermarking
+{:.no_toc}
+Outer joins have the same guarantees as [inner joins](#semantic-guarantees-of-stream-stream-inner-joins-with-watermarking)
+regarding watermark delays and whether data will be dropped or not.
+
+###### Caveats
+{:.no_toc}
+There are a few important characteristics to note regarding how the outer results are generated.
+
+- *The outer NULL results will be generated with a delay that depends on the specified watermark
+delay and the time range condition.* This is because the engine has to wait for that long to ensure
+there were no matches and there will be no more matches in the future.
+
+- In the current implementation in the micro-batch engine, watermarks are advanced at the end of a
+micro-batch, and the next micro-batch uses the updated watermark to clean up state and output
+outer results. Since we trigger a micro-batch only when there is new data to be processed, the
+generation of the outer result may get delayed if there is no new data being received in the stream.
+*In short, if either of the two input streams being joined does not receive data for a while, the
+outer (both cases, left or right) output may get delayed.*
+
+##### Support matrix for joins in streaming queries
+
<table class="table">
  <tr>
    <th>Left Input</th>
    <th>Right Input</th>
    <th>Join Type</th>
    <th></th>
  </tr>
  <tr>
    <td>Static</td>
    <td>Static</td>
    <td>All types</td>
    <td>Supported, since it's not on streaming data even though it can be present in a streaming query</td>
  </tr>
  <tr>
    <td rowspan="4">Stream</td>
    <td rowspan="4">Static</td>
    <td>Inner</td>
    <td>Supported, not stateful</td>
  </tr>
  <tr>
    <td>Left Outer</td>
    <td>Supported, not stateful</td>
  </tr>
  <tr>
    <td>Right Outer</td>
    <td>Not supported</td>
  </tr>
  <tr>
    <td>Full Outer</td>
    <td>Not supported</td>
  </tr>
  <tr>
    <td rowspan="4">Static</td>
    <td rowspan="4">Stream</td>
    <td>Inner</td>
    <td>Supported, not stateful</td>
  </tr>
  <tr>
    <td>Left Outer</td>
    <td>Not supported</td>
  </tr>
  <tr>
    <td>Right Outer</td>
    <td>Supported, not stateful</td>
  </tr>
  <tr>
    <td>Full Outer</td>
    <td>Not supported</td>
  </tr>
  <tr>
    <td rowspan="4">Stream</td>
    <td rowspan="4">Stream</td>
    <td>Inner</td>
    <td>Supported, optionally specify watermark on both sides + time constraints for state cleanup</td>
  </tr>
  <tr>
    <td>Left Outer</td>
    <td>Conditionally supported, must specify watermark on right + time constraints for correct results, optionally specify watermark on left for all state cleanup</td>
  </tr>
  <tr>
    <td>Right Outer</td>
    <td>Conditionally supported, must specify watermark on left + time constraints for correct results, optionally specify watermark on right for all state cleanup</td>
  </tr>
  <tr>
    <td>Full Outer</td>
    <td>Not supported</td>
  </tr>
</table>
+
+Additional details on supported joins:
+
+- Joins can be cascaded, that is, you can do `df1.join(df2, ...).join(df3, ...).join(df4, ...)`.
+
+- As of Spark 2.3, you can use joins only when the query is in Append output mode. Other output modes are not yet supported.
+
+- As of Spark 2.3, you cannot use other non-map-like operations before joins. Here are a few examples of
+  what cannot be used.
+
+  - Cannot use streaming aggregations before joins.
+
+  - Cannot use mapGroupsWithState and flatMapGroupsWithState in Update mode before joins.
+
+ ### Streaming Deduplication You can deduplicate records in data streams using a unique identifier in the events. This is exactly the same as deduplication on static Datasets using a unique identifier column. The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking. @@ -1105,15 +1534,29 @@ streamingDf {% highlight python %} streamingDf = spark.readStream. ... -// Without watermark using guid column +# Without watermark using guid column streamingDf.dropDuplicates("guid") -// With watermark using guid and eventTime columns +# With watermark using guid and eventTime columns streamingDf \ .withWatermark("eventTime", "10 seconds") \ .dropDuplicates("guid", "eventTime") {% endhighlight %} +
      + +{% highlight r %} +streamingDf <- read.stream(...) + +# Without watermark using guid column +streamingDf <- dropDuplicates(streamingDf, "guid") + +# With watermark using guid and eventTime columns +streamingDf <- withWatermark(streamingDf, "eventTime", "10 seconds") +streamingDf <- dropDuplicates(streamingDf, "guid", "eventTime") +{% endhighlight %} +
@@ -1132,15 +1575,9 @@ Some of them are as follows. - Sorting operations are supported on streaming Datasets only after an aggregation and in Complete Output Mode. -- Outer joins between a streaming and a static Datasets are conditionally supported. - - + Full outer join with a streaming Dataset is not supported - - + Left outer join with a streaming Dataset on the right is not supported - - + Right outer join with a streaming Dataset on the left is not supported - -- Any kind of joins between two streaming Datasets is not yet supported. +- A few types of outer joins on streaming Datasets are not supported. See the + support matrix in the Join Operations section + for more details. In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that). @@ -1248,6 +1685,15 @@ Here is the compatibility matrix. Aggregations not allowed after flatMapGroupsWithState.
<tr>
  <td colspan="2">Queries with joins</td>
  <td>Append</td>
  <td>
    Update and Complete mode not supported yet. See the
    <a href="#support-matrix-for-joins-in-streaming-queries">support matrix in the Join Operations section</a>
    for more details on what types of joins are supported.
  </td>
</tr>
      Other queries Append, Update
<table class="table">
  <tr>
    <th>Trigger Type</th>
    <th>Description</th>
  </tr>
  <tr>
    <td>unspecified (default)</td>
    <td>
      If no trigger setting is explicitly specified, then by default, the query will be
      executed in micro-batch mode, where micro-batches will be generated as soon as
      the previous micro-batch has completed processing.
    </td>
  </tr>
  <tr>
    <td>Fixed interval micro-batches</td>
    <td>
      The query will be executed in micro-batch mode, where micro-batches will be kicked off
      at the user-specified intervals.
      <ul>
        <li>If the previous micro-batch completes within the interval, then the engine will wait until
        the interval is over before kicking off the next micro-batch.</li>
        <li>If the previous micro-batch takes longer than the interval to complete (i.e. if an
        interval boundary is missed), then the next micro-batch will start as soon as the
        previous one completes (i.e., it will not wait for the next interval boundary).</li>
        <li>If no new data is available, then no micro-batch will be kicked off.</li>
      </ul>
    </td>
  </tr>
  <tr>
    <td>One-time micro-batch</td>
    <td>
      The query will execute <i>only one</i> micro-batch to process all the available data and then
      stop on its own. This is useful in scenarios where you want to periodically spin up a cluster,
      process everything that is available since the last period, and then shut down the
      cluster. In some cases, this may lead to significant cost savings.
    </td>
  </tr>
  <tr>
    <td>Continuous with fixed checkpoint interval<br/>(experimental)</td>
    <td>
      The query will be executed in the new low-latency, continuous processing mode. Read more
      about this in the Continuous Processing section below.
    </td>
  </tr>
</table>
      + +Here are a few code examples. + +
      +
      + +{% highlight scala %} +import org.apache.spark.sql.streaming.Trigger + +// Default trigger (runs micro-batch as soon as it can) +df.writeStream + .format("console") + .start() + +// ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream + .format("console") + .trigger(Trigger.ProcessingTime("2 seconds")) + .start() + +// One-time trigger +df.writeStream + .format("console") + .trigger(Trigger.Once()) + .start() + +// Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(Trigger.Continuous("1 second")) + .start() + +{% endhighlight %} + + +
      +
+
+{% highlight java %}
+import org.apache.spark.sql.streaming.Trigger;
+
+// Default trigger (runs micro-batch as soon as it can)
+df.writeStream
+  .format("console")
+  .start();
+
+// ProcessingTime trigger with two-seconds micro-batch interval
+df.writeStream
+  .format("console")
+  .trigger(Trigger.ProcessingTime("2 seconds"))
+  .start();
+
+// One-time trigger
+df.writeStream
+  .format("console")
+  .trigger(Trigger.Once())
+  .start();
+
+// Continuous trigger with one-second checkpointing interval
+df.writeStream
+  .format("console")
+  .trigger(Trigger.Continuous("1 second"))
+  .start();
+
+{% endhighlight %}
+
      +
+
+{% highlight python %}
+
+# Default trigger (runs micro-batch as soon as it can)
+df.writeStream \
+  .format("console") \
+  .start()
+
+# ProcessingTime trigger with two-seconds micro-batch interval
+df.writeStream \
+  .format("console") \
+  .trigger(processingTime='2 seconds') \
+  .start()
+
+# One-time trigger
+df.writeStream \
+  .format("console") \
+  .trigger(once=True) \
+  .start()
+
+# Continuous trigger with one-second checkpointing interval
+df.writeStream \
+  .format("console") \
+  .trigger(continuous='1 second') \
+  .start()
+
+{% endhighlight %}
      +
      + +{% highlight r %} +# Default trigger (runs micro-batch as soon as it can) +write.stream(df, "console") + +# ProcessingTime trigger with two-seconds micro-batch interval +write.stream(df, "console", trigger.processingTime = "2 seconds") + +# One-time trigger +write.stream(df, "console", trigger.once = TRUE) + +# Continuous trigger is not yet supported +{% endhighlight %} +
      +
      + + ## Managing Streaming Queries The `StreamingQuery` object created when a query is started can be used to monitor and manage the query. @@ -2049,7 +2661,7 @@ sql("SET spark.sql.streaming.metricsEnabled=true") All queries started in the SparkSession after this configuration has been enabled will report metrics through Dropwizard to whatever [sinks](monitoring.html#metrics) have been configured (e.g. Ganglia, Graphite, JMX, etc.). ## Recovering from Failures with Checkpointing -In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). +In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write-ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
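For instance, a minimal Scala sketch of setting the checkpoint location when starting a query might look like the following; the aggregated DataFrame `aggDF`, the memory sink, and the path are illustrative placeholders, not prescriptions from this guide.

{% highlight scala %}
// Minimal sketch: the checkpoint path and the sink are placeholders
aggDF
  .writeStream
  .outputMode("complete")
  .option("checkpointLocation", "path/to/HDFS/dir")
  .format("memory")
  .start()
{% endhighlight %}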
      @@ -2097,6 +2709,107 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat
+# Continuous Processing
+## [Experimental]
+{:.no_toc}
+
+**Continuous processing** is a new, experimental streaming execution mode introduced in Spark 2.3 that enables low (~1 ms) end-to-end latency with at-least-once fault-tolerance guarantees. Compare this with the default *micro-batch processing* engine, which can achieve exactly-once guarantees but achieves latencies of ~100ms at best. For some types of queries (discussed below), you can choose which mode to execute them in without modifying the application logic (i.e. without changing the DataFrame/Dataset operations).
+
+To run a supported query in continuous processing mode, all you need to do is specify a **continuous trigger** with the desired checkpoint interval as a parameter. For example,
      +
+
+{% highlight scala %}
+import org.apache.spark.sql.streaming.Trigger
+
+spark
+  .readStream
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribe", "topic1")
+  .load()
+  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+  .writeStream
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("topic", "topic1")
+  .trigger(Trigger.Continuous("1 second"))  // only change in query
+  .start()
+{% endhighlight %}
      +
      +{% highlight java %} +import org.apache.spark.sql.streaming.Trigger; + +spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .trigger(Trigger.Continuous("1 second")) // only change in query + .start(); +{% endhighlight %} +
      +
+
+{% highlight python %}
+spark \
+  .readStream \
+  .format("kafka") \
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+  .option("subscribe", "topic1") \
+  .load() \
+  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \
+  .writeStream \
+  .format("kafka") \
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+  .option("topic", "topic1") \
+  .trigger(continuous="1 second") \
+  .start()  # the continuous trigger is the only change in the query
+
+{% endhighlight %}
      +
+
+A checkpoint interval of 1 second means that the continuous processing engine will record the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started in micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees.
+
+## Supported Queries
+{:.no_toc}
+
+As of Spark 2.3, only the following types of queries are supported in the continuous processing mode.
+
+- *Operations*: Only map-like Dataset/DataFrame operations are supported in continuous mode, that is, only projections (`select`, `map`, `flatMap`, `mapPartitions`, etc.) and selections (`where`, `filter`, etc.).
+  + All SQL functions are supported except aggregation functions (since aggregations are not yet supported), `current_timestamp()` and `current_date()` (deterministic computations using time are challenging).
+
+- *Sources*:
+  + Kafka source: All options are supported.
+  + Rate source: Good for testing. The only options supported in continuous mode are `numPartitions` and `rowsPerSecond`.
+
+- *Sinks*:
+  + Kafka sink: All options are supported.
+  + Memory sink: Good for debugging.
+  + Console sink: Good for debugging. All options are supported. Note that the console will print every checkpoint interval that you have specified in the continuous trigger.
+
+See the [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections for more details on them. While the console sink is good for testing, the end-to-end low-latency processing can be best observed with Kafka as the source and sink, as this allows the engine to process the data and make the results available in the output topic within milliseconds of the input data being available in the input topic. A minimal sketch of a query that stays within these restrictions appears at the end of this section.
+
+## Caveats
+{:.no_toc}
+
+- The continuous processing engine launches multiple long-running tasks that continuously read data from sources, process it and continuously write to sinks. The number of tasks required by the query depends on how many partitions the query can read from the sources in parallel. Therefore, before starting a continuous processing query, you must ensure there are enough cores in the cluster for all the tasks to run in parallel. For example, if you are reading from a Kafka topic that has 10 partitions, then the cluster must have at least 10 cores for the query to make progress.
+- Stopping a continuous processing stream may produce spurious task termination warnings. These can be safely ignored.
+- There are currently no automatic retries of failed tasks. Any failure will lead to the query being stopped, and it will need to be manually restarted from the checkpoint.
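As a minimal sketch of a continuous query that stays within the restrictions above, the following Scala snippet uses the rate source and the console sink with only projections and selections. The `timestamp` and `value` columns are the ones the rate source produces; the specific option values and the derived column name `scaled` are illustrative assumptions.

{% highlight scala %}
import org.apache.spark.sql.streaming.Trigger

spark
  .readStream
  .format("rate")
  .option("rowsPerSecond", "10")   // rate source options supported in continuous mode
  .option("numPartitions", "2")
  .load()
  .where("value % 2 = 0")                            // selection (map-like)
  .selectExpr("timestamp", "value * 10 AS scaled")   // projection (map-like)
  .writeStream
  .format("console")
  .trigger(Trigger.Continuous("1 second"))
  .start()
{% endhighlight %}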
+ # Additional Information **Further Reading** @@ -2114,6 +2827,11 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat **Talks** -- Spark Summit 2017 Talk - [Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark](https://spark-summit.org/2017/events/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark/) -- Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) +- Spark Summit Europe 2017 + - Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark - + [Part 1 slides/video](https://databricks.com/session/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark), [Part 2 slides/video](https://databricks.com/session/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark-continues) + - Deep Dive into Stateful Stream Processing in Structured Streaming - [slides/video](https://databricks.com/session/deep-dive-into-stateful-stream-processing-in-structured-streaming) +- Spark Summit 2016 + - A Deep Dive into Structured Streaming - [slides/video](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) + diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 0473ab73a5e6..77aa083c4a58 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -5,7 +5,7 @@ title: Submitting Applications The `spark-submit` script in Spark's `bin` directory is used to launch applications on a cluster. It can use all of Spark's supported [cluster managers](cluster-overview.html#cluster-manager-types) -through a uniform interface so you don't have to configure your application specially for each one. +through a uniform interface so you don't have to configure your application especially for each one. # Bundling Your Application's Dependencies If your code depends on other projects, you will need to package them alongside @@ -58,7 +58,7 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Currently, standalone mode does not support cluster mode for Python +the drivers and the executors. Currently, the standalone mode does not support cluster mode for Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, @@ -68,7 +68,7 @@ There are a few options available that are specific to the [cluster manager](cluster-overview.html#cluster-manager-types) that is being used. For example, with a [Spark standalone cluster](spark-standalone.html) with `cluster` deploy mode, you can also specify `--supervise` to make sure that the driver is automatically restarted if it -fails with non-zero exit code. To enumerate all such options available to `spark-submit`, +fails with a non-zero exit code. To enumerate all such options available to `spark-submit`, run it with `--help`. 
Here are a few examples of common options: {% highlight bash %} @@ -177,7 +177,7 @@ The master URL passed to Spark can be in one of the following formats: # Loading Configuration from a File The `spark-submit` script can load default [Spark configuration values](configuration.html) from a -properties file and pass them on to your application. By default it will read options +properties file and pass them on to your application. By default, it will read options from `conf/spark-defaults.conf` in the Spark directory. For more detail, see the section on [loading default configurations](configuration.html#loading-default-configurations). @@ -192,7 +192,7 @@ debugging information by running `spark-submit` with the `--verbose` option. # Advanced Dependency Management When using `spark-submit`, the application jar along with any jars included with the `--jars` option -will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included on the driver and executor classpaths. Directory expansion does not work with `--jars`. +will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included in the driver and executor classpaths. Directory expansion does not work with `--jars`. Spark uses the following URL scheme to allow different strategies for disseminating jars: diff --git a/docs/tuning.md b/docs/tuning.md index fc27713f28d4..912c39879be8 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -196,7 +196,7 @@ To further tune garbage collection, we first need to understand some basic infor * A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects that are alive from Eden and Survivor1 are copied to Survivor2. The Survivor regions are swapped. If an object is old - enough or Survivor2 is full, it is moved to Old. Finally when Old is close to full, a full GC is invoked. + enough or Survivor2 is full, it is moved to Old. Finally, when Old is close to full, a full GC is invoked. The goal of GC tuning in Spark is to ensure that only long-lived RDDs are stored in the Old generation and that the Young generation is sufficiently sized to store short-lived objects. 
This will help avoid full GCs to collect diff --git a/examples/pom.xml b/examples/pom.xml index 1791dbaad775..868110b8e35e 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java index dee56799d8ae..1529da16f051 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java @@ -18,10 +18,9 @@ package org.apache.spark.examples.ml; // $example on$ -import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; +import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; -import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -50,7 +49,7 @@ public static void main(String[] args) { // $example on$ // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier // example - LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + BinaryLogisticRegressionTrainingSummary trainingSummary = lrModel.binarySummary(); // Obtain the loss per iteration. double[] objectiveHistory = trainingSummary.objectiveHistory(); @@ -58,21 +57,15 @@ public static void main(String[] args) { System.out.println(lossPerIteration); } - // Obtain the metrics useful to judge performance on test data. - // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary - // classification problem. - BinaryLogisticRegressionSummary binarySummary = - (BinaryLogisticRegressionSummary) trainingSummary; - // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. - Dataset roc = binarySummary.roc(); + Dataset roc = trainingSummary.roc(); roc.show(); roc.select("FPR").show(); - System.out.println(binarySummary.areaUnderROC()); + System.out.println(trainingSummary.areaUnderROC()); // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with // this selected threshold. 
- Dataset fMeasure = binarySummary.fMeasureByThreshold(); + Dataset fMeasure = trainingSummary.fMeasureByThreshold(); double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)) .select("threshold").head().getDouble(0); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java index da410cba2b3f..801a82cd2f24 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java @@ -20,6 +20,7 @@ // $example on$ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -48,6 +49,67 @@ public static void main(String[] args) { // Print the coefficients and intercept for multinomial logistic regression System.out.println("Coefficients: \n" + lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector()); + LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + + // Obtain the loss per iteration. + double[] objectiveHistory = trainingSummary.objectiveHistory(); + for (double lossPerIteration : objectiveHistory) { + System.out.println(lossPerIteration); + } + + // for multiclass, we can inspect metrics on a per-label basis + System.out.println("False positive rate by label:"); + int i = 0; + double[] fprLabel = trainingSummary.falsePositiveRateByLabel(); + for (double fpr : fprLabel) { + System.out.println("label " + i + ": " + fpr); + i++; + } + + System.out.println("True positive rate by label:"); + i = 0; + double[] tprLabel = trainingSummary.truePositiveRateByLabel(); + for (double tpr : tprLabel) { + System.out.println("label " + i + ": " + tpr); + i++; + } + + System.out.println("Precision by label:"); + i = 0; + double[] precLabel = trainingSummary.precisionByLabel(); + for (double prec : precLabel) { + System.out.println("label " + i + ": " + prec); + i++; + } + + System.out.println("Recall by label:"); + i = 0; + double[] recLabel = trainingSummary.recallByLabel(); + for (double rec : recLabel) { + System.out.println("label " + i + ": " + rec); + i++; + } + + System.out.println("F-measure by label:"); + i = 0; + double[] fLabel = trainingSummary.fMeasureByLabel(); + for (double f : fLabel) { + System.out.println("label " + i + ": " + f); + i++; + } + + double accuracy = trainingSummary.accuracy(); + double falsePositiveRate = trainingSummary.weightedFalsePositiveRate(); + double truePositiveRate = trainingSummary.weightedTruePositiveRate(); + double fMeasure = trainingSummary.weightedFMeasure(); + double precision = trainingSummary.weightedPrecision(); + double recall = trainingSummary.weightedRecall(); + System.out.println("Accuracy: " + accuracy); + System.out.println("FPR: " + falsePositiveRate); + System.out.println("TPR: " + truePositiveRate); + System.out.println("F-measure: " + fMeasure); + System.out.println("Precision: " + precision); + System.out.println("Recall: " + recall); // $example off$ spark.stop(); diff --git 
a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java similarity index 62% rename from examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java rename to examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java index 99af37676ba9..6f93cff94b72 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java @@ -23,9 +23,8 @@ import java.util.Arrays; import java.util.List; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.ml.feature.OneHotEncoderEstimator; +import org.apache.spark.ml.feature.OneHotEncoderModel; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -35,41 +34,37 @@ import org.apache.spark.sql.types.StructType; // $example off$ -public class JavaOneHotEncoderExample { +public class JavaOneHotEncoderEstimatorExample { public static void main(String[] args) { SparkSession spark = SparkSession .builder() - .appName("JavaOneHotEncoderExample") + .appName("JavaOneHotEncoderEstimatorExample") .getOrCreate(); + // Note: categorical features are usually first encoded with StringIndexer // $example on$ List data = Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") + RowFactory.create(0.0, 1.0), + RowFactory.create(1.0, 0.0), + RowFactory.create(2.0, 1.0), + RowFactory.create(0.0, 2.0), + RowFactory.create(0.0, 1.0), + RowFactory.create(2.0, 0.0) ); StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) + new StructField("categoryIndex1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("categoryIndex2", DataTypes.DoubleType, false, Metadata.empty()) }); Dataset df = spark.createDataFrame(data, schema); - StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); - Dataset indexed = indexer.transform(df); + OneHotEncoderEstimator encoder = new OneHotEncoderEstimator() + .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) + .setOutputCols(new String[] {"categoryVec1", "categoryVec2"}); - OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); - - Dataset encoded = encoder.transform(indexed); + OneHotEncoderModel model = encoder.fit(df); + Dataset encoded = model.transform(df); encoded.show(); // $example off$ diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java index dd20cac62110..43cc30c1a899 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -66,7 +66,7 @@ public static void main(String[] args) { .setNumBuckets(3); Dataset result = discretizer.fit(df).transform(df); - result.show(); + 
result.show(false); // $example off$ spark.stop(); } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java new file mode 100644 index 000000000000..d649a2ccbaa7 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.SparkSession; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.ml.feature.VectorSizeHint; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorSizeHintExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaVectorSizeHintExample") + .getOrCreate(); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row0 = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + Row row1 = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0), 0.0); + Dataset dataset = spark.createDataFrame(Arrays.asList(row0, row1), schema); + + VectorSizeHint sizeHint = new VectorSizeHint() + .setInputCol("userFeatures") + .setHandleInvalid("skip") + .setSize(3); + + Dataset datasetWithSize = sizeHint.transform(dataset); + System.out.println("Rows where 'userFeatures' is not the right size are filtered out"); + datasetWithSize.show(false); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + // This dataframe can be used by downstream transformers as before + Dataset output = assembler.transform(datasetWithSize); + System.out.println("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column " + + "'features'"); + output.select("features", "clicked").show(false); + // $example off$ + + spark.stop(); + } +} + diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 4422f9e7a958..a18722c687f8 100644 --- 
a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -15,13 +15,6 @@ # limitations under the License. # -from __future__ import print_function - -import sys - -from functools import reduce -from pyspark.sql import SparkSession - """ Read data file users.avro in local Spark distro: @@ -50,6 +43,13 @@ {u'favorite_color': None, u'name': u'Alyssa'} {u'favorite_color': u'red', u'name': u'Ben'} """ +from __future__ import print_function + +import sys + +from functools import reduce +from pyspark.sql import SparkSession + if __name__ == "__main__": if len(sys.argv) != 2 and len(sys.argv) != 3: print(""" @@ -61,7 +61,7 @@ Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. """, file=sys.stderr) - exit(-1) + sys.exit(-1) path = sys.argv[1] diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 92e0a3ae2ee6..a42d711fc505 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -49,7 +49,7 @@ def closestPoint(p, centers): if len(sys.argv) != 4: print("Usage: kmeans ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("""WARN: This is a naive implementation of KMeans Clustering and is given as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 01c938454b10..bcc4e0f4e8ea 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -48,7 +48,7 @@ def readPointBatch(iterator): if len(sys.argv) != 3: print("Usage: logistic_regression ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("""WARN: This is a naive implementation of Logistic Regression and is given as an example! diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py index 2f0ca995e55c..0a71f76418ea 100644 --- a/examples/src/main/python/ml/aft_survival_regression.py +++ b/examples/src/main/python/ml/aft_survival_regression.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating aft survival regression. +Run with: + bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py +""" from __future__ import print_function # $example on$ @@ -23,12 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating aft survival regression. -Run with: - bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/bisecting_k_means_example.py b/examples/src/main/python/ml/bisecting_k_means_example.py index 1263cb5d177a..7842d2009e23 100644 --- a/examples/src/main/python/ml/bisecting_k_means_example.py +++ b/examples/src/main/python/ml/bisecting_k_means_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating bisecting k-means clustering. +Run with: + bin/spark-submit examples/src/main/python/ml/bisecting_k_means_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating bisecting k-means clustering. 
-Run with: - bin/spark-submit examples/src/main/python/ml/bisecting_k_means_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py b/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py index 1b7a458125ce..610176ea596c 100644 --- a/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py +++ b/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating BucketedRandomProjectionLSH. +Run with: + bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh_example.py +""" from __future__ import print_function # $example on$ @@ -25,12 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating BucketedRandomProjectionLSH. -Run with: - bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/chi_square_test_example.py b/examples/src/main/python/ml/chi_square_test_example.py index 8f25318ded00..2af7e683cdb7 100644 --- a/examples/src/main/python/ml/chi_square_test_example.py +++ b/examples/src/main/python/ml/chi_square_test_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example for Chi-square hypothesis testing. +Run with: + bin/spark-submit examples/src/main/python/ml/chi_square_test_example.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -23,11 +28,6 @@ from pyspark.ml.stat import ChiSquareTest # $example off$ -""" -An example for Chi-square hypothesis testing. -Run with: - bin/spark-submit examples/src/main/python/ml/chi_square_test_example.py -""" if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/correlation_example.py b/examples/src/main/python/ml/correlation_example.py index 0a9d30da5a42..1f4e402ac1a5 100644 --- a/examples/src/main/python/ml/correlation_example.py +++ b/examples/src/main/python/ml/correlation_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example for computing correlation matrix. +Run with: + bin/spark-submit examples/src/main/python/ml/correlation_example.py +""" from __future__ import print_function # $example on$ @@ -23,11 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example for computing correlation matrix. -Run with: - bin/spark-submit examples/src/main/python/ml/correlation_example.py -""" if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py index db7054307c2e..6256d11504af 100644 --- a/examples/src/main/python/ml/cross_validator.py +++ b/examples/src/main/python/ml/cross_validator.py @@ -15,6 +15,13 @@ # limitations under the License. # +""" +A simple example demonstrating model selection using CrossValidator. +This example also demonstrates how Pipelines are Estimators. +Run with: + + bin/spark-submit examples/src/main/python/ml/cross_validator.py +""" from __future__ import print_function # $example on$ @@ -26,14 +33,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -A simple example demonstrating model selection using CrossValidator. -This example also demonstrates how Pipelines are Estimators. 
-Run with: - - bin/spark-submit examples/src/main/python/ml/cross_validator.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py index 109f901012c9..cabc3de68f2f 100644 --- a/examples/src/main/python/ml/dataframe_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -17,7 +17,7 @@ """ An example of how to use DataFrame for ML. Run with:: - bin/spark-submit examples/src/main/python/ml/dataframe_example.py + bin/spark-submit examples/src/main/python/ml/dataframe_example.py """ from __future__ import print_function @@ -33,20 +33,20 @@ if __name__ == "__main__": if len(sys.argv) > 2: print("Usage: dataframe_example.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) elif len(sys.argv) == 2: - input = sys.argv[1] + input_path = sys.argv[1] else: - input = "data/mllib/sample_libsvm_data.txt" + input_path = "data/mllib/sample_libsvm_data.txt" spark = SparkSession \ .builder \ .appName("DataFrameExample") \ .getOrCreate() - # Load input data - print("Loading LIBSVM file with UDT from " + input + ".") - df = spark.read.format("libsvm").load(input).cache() + # Load an input file + print("Loading LIBSVM file with UDT from " + input_path + ".") + df = spark.read.format("libsvm").load(input_path).cache() print("Schema from LIBSVM:") df.printSchema() print("Loaded training data as a DataFrame with " + diff --git a/examples/src/main/python/ml/fpgrowth_example.py b/examples/src/main/python/ml/fpgrowth_example.py index c92c3c27abb2..39092e616d42 100644 --- a/examples/src/main/python/ml/fpgrowth_example.py +++ b/examples/src/main/python/ml/fpgrowth_example.py @@ -15,16 +15,15 @@ # limitations under the License. # -# $example on$ -from pyspark.ml.fpm import FPGrowth -# $example off$ -from pyspark.sql import SparkSession - """ An example demonstrating FPGrowth. Run with: bin/spark-submit examples/src/main/python/ml/fpgrowth_example.py """ +# $example on$ +from pyspark.ml.fpm import FPGrowth +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/ml/gaussian_mixture_example.py b/examples/src/main/python/ml/gaussian_mixture_example.py index e4a0d314e9d9..4938a904189f 100644 --- a/examples/src/main/python/ml/gaussian_mixture_example.py +++ b/examples/src/main/python/ml/gaussian_mixture_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Gaussian Mixture Model (GMM). +Run with: + bin/spark-submit examples/src/main/python/ml/gaussian_mixture_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -A simple example demonstrating Gaussian Mixture Model (GMM). -Run with: - bin/spark-submit examples/src/main/python/ml/gaussian_mixture_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/generalized_linear_regression_example.py b/examples/src/main/python/ml/generalized_linear_regression_example.py index 796752a60f3a..a52f4650c1c6 100644 --- a/examples/src/main/python/ml/generalized_linear_regression_example.py +++ b/examples/src/main/python/ml/generalized_linear_regression_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating generalized linear regression. 
+Run with: + bin/spark-submit examples/src/main/python/ml/generalized_linear_regression_example.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -22,12 +27,6 @@ from pyspark.ml.regression import GeneralizedLinearRegression # $example off$ -""" -An example demonstrating generalized linear regression. -Run with: - bin/spark-submit examples/src/main/python/ml/generalized_linear_regression_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/imputer_example.py b/examples/src/main/python/ml/imputer_example.py index b8437f827e56..9ba014776361 100644 --- a/examples/src/main/python/ml/imputer_example.py +++ b/examples/src/main/python/ml/imputer_example.py @@ -15,16 +15,15 @@ # limitations under the License. # -# $example on$ -from pyspark.ml.feature import Imputer -# $example off$ -from pyspark.sql import SparkSession - """ An example demonstrating Imputer. Run with: bin/spark-submit examples/src/main/python/ml/imputer_example.py """ +# $example on$ +from pyspark.ml.feature import Imputer +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/ml/isotonic_regression_example.py b/examples/src/main/python/ml/isotonic_regression_example.py index 6ae15f1b4b0d..89cba9dfc7e8 100644 --- a/examples/src/main/python/ml/isotonic_regression_example.py +++ b/examples/src/main/python/ml/isotonic_regression_example.py @@ -17,6 +17,9 @@ """ Isotonic Regression Example. + +Run with: + bin/spark-submit examples/src/main/python/ml/isotonic_regression_example.py """ from __future__ import print_function @@ -25,12 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating isotonic regression. -Run with: - bin/spark-submit examples/src/main/python/ml/isotonic_regression_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py index 5f77843e3743..80a878af679f 100644 --- a/examples/src/main/python/ml/kmeans_example.py +++ b/examples/src/main/python/ml/kmeans_example.py @@ -15,6 +15,13 @@ # limitations under the License. # +""" +An example demonstrating k-means clustering. +Run with: + bin/spark-submit examples/src/main/python/ml/kmeans_example.py + +This example requires NumPy (http://www.numpy.org/). +""" from __future__ import print_function # $example on$ @@ -24,14 +31,6 @@ from pyspark.sql import SparkSession -""" -An example demonstrating k-means clustering. -Run with: - bin/spark-submit examples/src/main/python/ml/kmeans_example.py - -This example requires NumPy (http://www.numpy.org/). -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/lda_example.py b/examples/src/main/python/ml/lda_example.py index a8b346f72cd6..97d1a042d147 100644 --- a/examples/src/main/python/ml/lda_example.py +++ b/examples/src/main/python/ml/lda_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating LDA. +Run with: + bin/spark-submit examples/src/main/python/ml/lda_example.py +""" from __future__ import print_function # $example on$ @@ -23,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating LDA. 
-Run with: - bin/spark-submit examples/src/main/python/ml/lda_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/logistic_regression_summary_example.py b/examples/src/main/python/ml/logistic_regression_summary_example.py index bd440a1fbe8d..2274ff707b2a 100644 --- a/examples/src/main/python/ml/logistic_regression_summary_example.py +++ b/examples/src/main/python/ml/logistic_regression_summary_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating Logistic Regression Summary. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating Logistic Regression Summary. -Run with: - bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/min_hash_lsh_example.py b/examples/src/main/python/ml/min_hash_lsh_example.py index 7b1dd611a865..93136e6ae3ca 100644 --- a/examples/src/main/python/ml/min_hash_lsh_example.py +++ b/examples/src/main/python/ml/min_hash_lsh_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating MinHashLSH. +Run with: + bin/spark-submit examples/src/main/python/ml/min_hash_lsh_example.py +""" from __future__ import print_function # $example on$ @@ -25,12 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating MinHashLSH. -Run with: - bin/spark-submit examples/src/main/python/ml/min_hash_lsh_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py index bb9cd82d6ba2..bec9860c79a2 100644 --- a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py @@ -43,6 +43,44 @@ # Print the coefficients and intercept for multinomial logistic regression print("Coefficients: \n" + str(lrModel.coefficientMatrix)) print("Intercept: " + str(lrModel.interceptVector)) + + trainingSummary = lrModel.summary + + # Obtain the objective per iteration + objectiveHistory = trainingSummary.objectiveHistory + print("objectiveHistory:") + for objective in objectiveHistory: + print(objective) + + # for multiclass, we can inspect metrics on a per-label basis + print("False positive rate by label:") + for i, rate in enumerate(trainingSummary.falsePositiveRateByLabel): + print("label %d: %s" % (i, rate)) + + print("True positive rate by label:") + for i, rate in enumerate(trainingSummary.truePositiveRateByLabel): + print("label %d: %s" % (i, rate)) + + print("Precision by label:") + for i, prec in enumerate(trainingSummary.precisionByLabel): + print("label %d: %s" % (i, prec)) + + print("Recall by label:") + for i, rec in enumerate(trainingSummary.recallByLabel): + print("label %d: %s" % (i, rec)) + + print("F-measure by label:") + for i, f in enumerate(trainingSummary.fMeasureByLabel()): + print("label %d: %s" % (i, f)) + + accuracy = trainingSummary.accuracy + falsePositiveRate = trainingSummary.weightedFalsePositiveRate + truePositiveRate = trainingSummary.weightedTruePositiveRate + fMeasure = 
trainingSummary.weightedFMeasure() + precision = trainingSummary.weightedPrecision + recall = trainingSummary.weightedRecall + print("Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s" + % (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall)) # $example off$ spark.stop() diff --git a/examples/src/main/python/ml/one_vs_rest_example.py b/examples/src/main/python/ml/one_vs_rest_example.py index 8e00c25d9342..956e94ae4ab6 100644 --- a/examples/src/main/python/ml/one_vs_rest_example.py +++ b/examples/src/main/python/ml/one_vs_rest_example.py @@ -15,6 +15,12 @@ # limitations under the License. # +""" +An example of Multiclass to Binary Reduction with One Vs Rest, +using Logistic Regression as the base classifier. +Run with: + bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py +""" from __future__ import print_function # $example on$ @@ -23,13 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example of Multiclass to Binary Reduction with One Vs Rest, -using Logistic Regression as the base classifier. -Run with: - bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_estimator_example.py similarity index 65% rename from examples/src/main/python/ml/onehot_encoder_example.py rename to examples/src/main/python/ml/onehot_encoder_estimator_example.py index e1996c7f0a55..2723e681cea7 100644 --- a/examples/src/main/python/ml/onehot_encoder_example.py +++ b/examples/src/main/python/ml/onehot_encoder_estimator_example.py @@ -18,32 +18,31 @@ from __future__ import print_function # $example on$ -from pyspark.ml.feature import OneHotEncoder, StringIndexer +from pyspark.ml.feature import OneHotEncoderEstimator # $example off$ from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ .builder\ - .appName("OneHotEncoderExample")\ + .appName("OneHotEncoderEstimatorExample")\ .getOrCreate() + # Note: categorical features are usually first encoded with StringIndexer # $example on$ df = spark.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - ], ["id", "category"]) + (0.0, 1.0), + (1.0, 0.0), + (2.0, 1.0), + (0.0, 2.0), + (0.0, 1.0), + (2.0, 0.0) + ], ["categoryIndex1", "categoryIndex2"]) - stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") - model = stringIndexer.fit(df) - indexed = model.transform(df) - - encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec") - encoded = encoder.transform(indexed) + encoder = OneHotEncoderEstimator(inputCols=["categoryIndex1", "categoryIndex2"], + outputCols=["categoryVec1", "categoryVec2"]) + model = encoder.fit(df) + encoded = model.transform(df) encoded.show() # $example off$ diff --git a/examples/src/main/python/ml/train_validation_split.py b/examples/src/main/python/ml/train_validation_split.py index d104f7d30a1b..d4f9184bf576 100644 --- a/examples/src/main/python/ml/train_validation_split.py +++ b/examples/src/main/python/ml/train_validation_split.py @@ -15,13 +15,6 @@ # limitations under the License. 
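The renamed one-hot encoding example above feeds already-numeric category indices into OneHotEncoderEstimator and only notes in a comment that string columns are usually indexed first. A minimal sketch of that fuller pipeline, assuming the OneHotEncoderEstimator and StringIndexer APIs used in the diff; the data, column names, and app name are illustrative:

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer
    from pyspark.sql import SparkSession

    if __name__ == "__main__":
        spark = SparkSession.builder.appName("StringIndexerThenOneHotSketch").getOrCreate()

        df = spark.createDataFrame(
            [(0, "a"), (1, "b"), (2, "c"), (3, "a")], ["id", "category"])

        # Index the string column first, then one-hot encode the resulting index.
        indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
        encoder = OneHotEncoderEstimator(inputCols=["categoryIndex"],
                                         outputCols=["categoryVec"])
        pipeline = Pipeline(stages=[indexer, encoder])

        pipeline.fit(df).transform(df).show()

        spark.stop()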
# -# $example on$ -from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.ml.regression import LinearRegression -from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit -# $example off$ -from pyspark.sql import SparkSession - """ This example demonstrates applying TrainValidationSplit to split data and perform model selection. @@ -29,6 +22,12 @@ bin/spark-submit examples/src/main/python/ml/train_validation_split.py """ +# $example on$ +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.regression import LinearRegression +from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/ml/vector_size_hint_example.py b/examples/src/main/python/ml/vector_size_hint_example.py new file mode 100644 index 000000000000..fb77dacec629 --- /dev/null +++ b/examples/src/main/python/ml/vector_size_hint_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.linalg import Vectors +from pyspark.ml.feature import (VectorSizeHint, VectorAssembler) +# $example off$ +from pyspark.sql import SparkSession + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("VectorSizeHintExample")\ + .getOrCreate() + + # $example on$ + dataset = spark.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0), + (0, 18, 1.0, Vectors.dense([0.0, 10.0]), 0.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + + sizeHint = VectorSizeHint( + inputCol="userFeatures", + handleInvalid="skip", + size=3) + + datasetWithSize = sizeHint.transform(dataset) + print("Rows where 'userFeatures' is not the right size are filtered out") + datasetWithSize.show(truncate=False) + + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + + # This dataframe can be used by downstream transformers as before + output = assembler.transform(datasetWithSize) + print("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column 'features'") + output.select("features", "clicked").show(truncate=False) + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 0e13546b88e6..089504fa7064 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -31,7 +31,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: correlations ()", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonCorrelations") if len(sys.argv) == 2: filepath = sys.argv[1] diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index 002fc7579964..1bdb3e9b4a2a 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -36,7 +36,7 @@ def parseVector(line): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: kmeans ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="KMeans") lines = sc.textFile(sys.argv[1]) data = lines.map(parseVector) diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index d4f1d34e2d8c..87efe1737522 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -42,7 +42,7 @@ def parsePoint(line): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: logistic_regression ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).map(parsePoint) iterations = int(sys.argv[2]) diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index 729bae30b152..9a429b5f8abd 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -29,7 +29,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: random_rdd_generation", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonRandomRDDGeneration") diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index b7033ab7daeb..00e7cf4bbcdb 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -29,7 +29,7 
@@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: sampled_rdds ", file=sys.stderr) - exit(-1) + sys.exit(-1) if len(sys.argv) == 2: datapath = sys.argv[1] else: @@ -43,7 +43,7 @@ numExamples = examples.count() if numExamples == 0: print("Error: Data file had no samples to load.", file=sys.stderr) - exit(1) + sys.exit(1) print('Loaded data with %d examples from file: %s' % (numExamples, datapath)) # Example: RDD.sample() and RDD.takeSample() diff --git a/examples/src/main/python/mllib/streaming_linear_regression_example.py b/examples/src/main/python/mllib/streaming_linear_regression_example.py index f600496867c1..714c9a0de721 100644 --- a/examples/src/main/python/mllib/streaming_linear_regression_example.py +++ b/examples/src/main/python/mllib/streaming_linear_regression_example.py @@ -36,7 +36,7 @@ if len(sys.argv) != 3: print("Usage: streaming_linear_regression_example.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonLogisticRegressionWithLBFGSExample") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 0d6c253d397a..2c19e8700ab1 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -47,7 +47,7 @@ def parseNeighbors(urls): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: pagerank ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("WARN: This is a naive implementation of PageRank and is given as an example!\n" + "Please refer to PageRank implementation provided by graphx", diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index 52e9662d528d..83041f0040a0 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -15,12 +15,6 @@ # limitations under the License. # -from __future__ import print_function - -import sys - -from pyspark.sql import SparkSession - """ Read data file users.parquet in local Spark distro: @@ -35,6 +29,12 @@ {u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} <...more log output...> """ +from __future__ import print_function + +import sys + +from pyspark.sql import SparkSession + if __name__ == "__main__": if len(sys.argv) != 2: print(""" @@ -45,7 +45,7 @@ /path/to/examples/parquet_inputformat.py Assumes you have Parquet data stored in . """, file=sys.stderr) - exit(-1) + sys.exit(-1) path = sys.argv[1] diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index 81898cf6d5ce..d3cd985d197e 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -25,7 +25,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: sort ", file=sys.stderr) - exit(-1) + sys.exit(-1) spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py new file mode 100644 index 000000000000..4c5aefb6ff4a --- /dev/null +++ b/examples/src/main/python/sql/arrow.py @@ -0,0 +1,129 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A simple example demonstrating Arrow in Spark. +Run with: + ./bin/spark-submit examples/src/main/python/sql/arrow.py +""" + +from __future__ import print_function + +from pyspark.sql import SparkSession +from pyspark.sql.utils import require_minimum_pandas_version, require_minimum_pyarrow_version + +require_minimum_pandas_version() +require_minimum_pyarrow_version() + + +def dataframe_with_arrow_example(spark): + # $example on:dataframe_with_arrow$ + import numpy as np + import pandas as pd + + # Enable Arrow-based columnar data transfers + spark.conf.set("spark.sql.execution.arrow.enabled", "true") + + # Generate a Pandas DataFrame + pdf = pd.DataFrame(np.random.rand(100, 3)) + + # Create a Spark DataFrame from a Pandas DataFrame using Arrow + df = spark.createDataFrame(pdf) + + # Convert the Spark DataFrame back to a Pandas DataFrame using Arrow + result_pdf = df.select("*").toPandas() + # $example off:dataframe_with_arrow$ + print("Pandas DataFrame result statistics:\n%s\n" % str(result_pdf.describe())) + + +def scalar_pandas_udf_example(spark): + # $example on:scalar_pandas_udf$ + import pandas as pd + + from pyspark.sql.functions import col, pandas_udf + from pyspark.sql.types import LongType + + # Declare the function and create the UDF + def multiply_func(a, b): + return a * b + + multiply = pandas_udf(multiply_func, returnType=LongType()) + + # The function for a pandas_udf should be able to execute with local Pandas data + x = pd.Series([1, 2, 3]) + print(multiply_func(x, x)) + # 0 1 + # 1 4 + # 2 9 + # dtype: int64 + + # Create a Spark DataFrame, 'spark' is an existing SparkSession + df = spark.createDataFrame(pd.DataFrame(x, columns=["x"])) + + # Execute function as a Spark vectorized UDF + df.select(multiply(col("x"), col("x"))).show() + # +-------------------+ + # |multiply_func(x, x)| + # +-------------------+ + # | 1| + # | 4| + # | 9| + # +-------------------+ + # $example off:scalar_pandas_udf$ + + +def grouped_map_pandas_udf_example(spark): + # $example on:grouped_map_pandas_udf$ + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ("id", "v")) + + @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) + def substract_mean(pdf): + # pdf is a pandas.DataFrame + v = pdf.v + return pdf.assign(v=v - v.mean()) + + df.groupby("id").apply(substract_mean).show() + # +---+----+ + # | id| v| + # +---+----+ + # | 1|-0.5| + # | 1| 0.5| + # | 2|-3.0| + # | 2|-1.0| + # | 2| 4.0| + # +---+----+ + # $example off:grouped_map_pandas_udf$ + + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("Python Arrow-in-Spark example") \ + .getOrCreate() + + print("Running Pandas to/from conversion example") + dataframe_with_arrow_example(spark) + print("Running pandas_udf scalar example") + scalar_pandas_udf_example(spark) + print("Running pandas_udf grouped map example") + grouped_map_pandas_udf_example(spark) + + spark.stop() diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py index c07fa8f2752b..c8fb25d0533b 100644 
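The new arrow.py above builds its scalar vectorized UDF with the functional form pandas_udf(multiply_func, returnType=LongType()). The same API can also be used as a decorator; a minimal sketch under that assumption, with an illustrative DataFrame and app name (like arrow.py, it needs Pandas and PyArrow installed):

    import pandas as pd

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, pandas_udf

    if __name__ == "__main__":
        spark = SparkSession.builder.appName("PandasUDFDecoratorSketch").getOrCreate()

        # Same vectorized multiply as in arrow.py, written with the decorator form.
        @pandas_udf("long")
        def multiply(a, b):
            # a and b arrive as pandas.Series; the return value is a Series of the declared type
            return a * b

        df = spark.createDataFrame(pd.DataFrame({"x": [1, 2, 3]}))
        df.select(multiply(col("x"), col("x"))).show()

        spark.stop()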
--- a/examples/src/main/python/sql/basic.py +++ b/examples/src/main/python/sql/basic.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating basic Spark SQL features. +Run with: + ./bin/spark-submit examples/src/main/python/sql/basic.py +""" from __future__ import print_function # $example on:init_session$ @@ -30,12 +35,6 @@ from pyspark.sql.types import * # $example off:programmatic_schema$ -""" -A simple example demonstrating basic Spark SQL features. -Run with: - ./bin/spark-submit examples/src/main/python/sql/basic.py -""" - def basic_df_example(spark): # $example on:create_df$ diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index b375fa775de3..d8c879dfe02e 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Spark SQL data sources. +Run with: + ./bin/spark-submit examples/src/main/python/sql/datasource.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -22,12 +27,6 @@ from pyspark.sql import Row # $example off:schema_merging$ -""" -A simple example demonstrating Spark SQL data sources. -Run with: - ./bin/spark-submit examples/src/main/python/sql/datasource.py -""" - def basic_datasource_example(spark): # $example on:generic_load_save_functions$ diff --git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py index 1f83a6fb48b9..33fc2dfbeefa 100644 --- a/examples/src/main/python/sql/hive.py +++ b/examples/src/main/python/sql/hive.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Spark SQL Hive integration. +Run with: + ./bin/spark-submit examples/src/main/python/sql/hive.py +""" from __future__ import print_function # $example on:spark_hive$ @@ -24,12 +29,6 @@ from pyspark.sql import Row # $example off:spark_hive$ -""" -A simple example demonstrating Spark SQL Hive integration. -Run with: - ./bin/spark-submit examples/src/main/python/sql/hive.py -""" - if __name__ == "__main__": # $example on:spark_hive$ diff --git a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py index 9e8a552b3b10..921067891352 100644 --- a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py @@ -49,7 +49,7 @@ print(""" Usage: structured_kafka_wordcount.py """, file=sys.stderr) - exit(-1) + sys.exit(-1) bootstrapServers = sys.argv[1] subscribeType = sys.argv[2] diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py index afde2550587c..9ac392164735 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network. 
Usage: structured_network_wordcount.py and describe the TCP server that Structured Streaming @@ -38,7 +38,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: structured_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) host = sys.argv[1] port = int(sys.argv[2]) diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py index 02a7d3363d78..c4e3bbf44cd5 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network over a sliding window of configurable duration. Each line from the network is tagged with a timestamp that is used to determine the windows into which it falls. @@ -53,7 +53,7 @@ msg = ("Usage: structured_network_wordcount_windowed.py " " []") print(msg, file=sys.stderr) - exit(-1) + sys.exit(-1) host = sys.argv[1] port = int(sys.argv[2]) diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 7097f7f4502b..c5c186c11f79 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text directly received from Kafka in every 2 seconds. Usage: direct_kafka_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: direct_kafka_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") ssc = StreamingContext(sc, 2) diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py index d75bc6daac13..c8ea92b61ca6 100644 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: flume_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: flume_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingFlumeWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py index f815dd26823d..f9a5c43a8eaa 100644 --- a/examples/src/main/python/streaming/hdfs_wordcount.py +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -35,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: hdfs_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingHDFSWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 8d697f620f46..e9ee08b9fd22 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. 
# -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: kafka_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: kafka_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingKafkaWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index 2b48bcfd55db..f3099d2517cd 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: network_wordcount.py and describe the TCP server that Spark Streaming would connect to receive data. @@ -35,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index b309d9fad33f..2b5434c0c845 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Shows the most positive words in UTF8 encoded, '\n' delimited text directly received the network every 5 seconds. The streaming data is joined with a static RDD of the AFINN word list (http://neuro.imm.dtu.dk/wiki/AFINN) @@ -47,7 +47,7 @@ def print_happiest_words(rdd): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: network_wordjoinsentiments.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordJoinSentiments") ssc = StreamingContext(sc, 5) diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index 52b2639cdf55..60167dc77254 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -101,7 +101,7 @@ def filterFunc(wordCount): if len(sys.argv) != 5: print("Usage: recoverable_network_wordcount.py " " ", file=sys.stderr) - exit(-1) + sys.exit(-1) host, port, checkpoint, output = sys.argv[1:] ssc = StreamingContext.getOrCreate(checkpoint, lambda: createContext(host, int(port), output)) diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index 398ac8d2d8f5..ab3cfc067994 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the network every second. 
@@ -48,7 +48,7 @@ def getSparkSessionInstance(sparkConf): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: sql_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) host, port = sys.argv[1:] sc = SparkContext(appName="PythonSqlNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index f8bbc659c2ea..d5d1eba6c596 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: stateful_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") ssc = StreamingContext(sc, 1) ssc.checkpoint("checkpoint") diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index 3d5e44d5b2df..a05e24ff3ff9 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -26,7 +26,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: wordcount ", file=sys.stderr) - exit(-1) + sys.exit(-1) spark = SparkSession\ .builder\ diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index e64dcbd182d9..2332a661f26a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -60,10 +60,6 @@ object SimpleSkewedGroupByTest { pairs1.count println(s"RESULT: ${pairs1.groupByKey(numReducers).count}") - // Print how many keys each reducer got (for debugging) - // println("RESULT: " + pairs1.groupByKey(numReducers) - // .map{case (k,v) => (k, v.size)} - // .collectAsMap) spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala new file mode 100644 index 000000000000..64076f2deb70 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples + +import java.io.File + +import org.apache.spark.SparkFiles +import org.apache.spark.sql.SparkSession + +/** Usage: SparkRemoteFileTest [file] */ +object SparkRemoteFileTest { + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: SparkRemoteFileTest ") + System.exit(1) + } + val spark = SparkSession + .builder() + .appName("SparkRemoteFileTest") + .getOrCreate() + val sc = spark.sparkContext + val rdd = sc.parallelize(Seq(1)).map(_ => { + val localLocation = SparkFiles.get(args(0)) + println(s"${args(0)} is stored at: $localLocation") + new File(localLocation).isFile + }) + val truthCheck = rdd.collect().head + println(s"Mounting of ${args(0)} was $truthCheck") + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 92936bd30dbc..815404d1218b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -145,9 +145,11 @@ object Analytics extends Logging { // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) - println("Triangles: " + triangles.vertices.map { + val triangleTypes = triangles.vertices.map { case (vid, data) => data.toLong - }.reduce(_ + _) / 3) + }.reduce(_ + _) / 3 + + println(s"Triangles: ${triangleTypes}") sc.stop() case _ => diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 6d2228c8742a..57b2edf99220 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -52,7 +52,7 @@ object SynthBenchmark { arg => arg.dropWhile(_ == '-').split('=') match { case Array(opt, v) => (opt -> v) - case _ => throw new IllegalArgumentException("Invalid argument: " + arg) + case _ => throw new IllegalArgumentException(s"Invalid argument: $arg") } } @@ -76,7 +76,7 @@ object SynthBenchmark { case ("sigma", v) => sigma = v.toDouble case ("degFile", v) => degFile = v case ("seed", v) => seed = v.toInt - case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) + case (opt, _) => throw new IllegalArgumentException(s"Invalid option: $opt") } val conf = new SparkConf() @@ -86,7 +86,7 @@ object SynthBenchmark { val sc = new SparkContext(conf) // Create the graph - println(s"Creating graph...") + println("Creating graph...") val unpartitionedGraph = GraphGenerators.logNormalGraph(sc, numVertices, numEPart.getOrElse(sc.defaultParallelism), mu, sigma, seed) // Repartition the graph diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala index dcee1e427ce5..5146fd031646 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala @@ -52,9 +52,9 @@ object ChiSquareTestExample { val df = data.toDF("label", "features") val chi = ChiSquareTest.test(df, "features", "label").head - println("pValues = " + chi.getAs[Vector](0)) - 
println("degreesOfFreedom = " + chi.getSeq[Int](1).mkString("[", ",", "]")) - println("statistics = " + chi.getAs[Vector](2)) + println(s"pValues = ${chi.getAs[Vector](0)}") + println(s"degreesOfFreedom ${chi.getSeq[Int](1).mkString("[", ",", "]")}") + println(s"statistics ${chi.getAs[Vector](2)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala index 3f57dc342eb0..d7f1fc8ed74d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala @@ -51,10 +51,10 @@ object CorrelationExample { val df = data.map(Tuple1.apply).toDF("features") val Row(coeff1: Matrix) = Correlation.corr(df, "features").head - println("Pearson correlation matrix:\n" + coeff1.toString) + println(s"Pearson correlation matrix:\n $coeff1") val Row(coeff2: Matrix) = Correlation.corr(df, "features", "spearman").head - println("Spearman correlation matrix:\n" + coeff2.toString) + println(s"Spearman correlation matrix:\n $coeff2") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 0658bddf1696..ee4469faab3a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -47,7 +47,7 @@ object DataFrameExample { val parser = new OptionParser[Params]("DataFrameExample") { head("DataFrameExample: an example app using DataFrame for ML.") opt[String]("input") - .text(s"input path to dataframe") + .text("input path to dataframe") .action((x, c) => c.copy(input = x)) checkConfig { params => success @@ -93,7 +93,7 @@ object DataFrameExample { // Load the records back. 
println(s"Loading Parquet file with UDT from $outputDir.") val newDF = spark.read.parquet(outputDir) - println(s"Schema from Parquet:") + println("Schema from Parquet:") newDF.printSchema() spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index bc6d3275933e..276cedab11ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -83,10 +83,10 @@ object DecisionTreeClassificationExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] - println("Learned classification tree model:\n" + treeModel.toDebugString) + println(s"Learned classification tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index ee61200ad1d0..aaaecaea4708 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -73,10 +73,10 @@ object DecisionTreeRegressionExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] - println("Learned regression tree model:\n" + treeModel.toDebugString) + println(s"Learned regression tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index d94d837d10e9..2dc11b07d88e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -53,7 +53,7 @@ object DeveloperApiExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new MyLogisticRegression() // Print out the parameters, documentation, and any default values. - println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"MyLogisticRegression parameters:\n ${lr.explainParams()}") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -169,10 +169,10 @@ private class MyLogisticRegressionModel( Vectors.dense(-margin, margin) } - /** Number of classes the label can take. 2 indicates binary classification. */ + // Number of classes the label can take. 2 indicates binary classification. override val numClasses: Int = 2 - /** Number of features the model was trained on. */ + // Number of features the model was trained on. 
override val numFeatures: Int = coefficients.size /** diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala index f18d86e1a692..e5d91f132a3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala @@ -46,7 +46,7 @@ object EstimatorTransformerParamExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() // Print out the parameters, documentation, and any default values. - println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"LogisticRegression parameters:\n ${lr.explainParams()}\n") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -58,7 +58,7 @@ object EstimatorTransformerParamExample { // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. - println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + println(s"Model 1 was fit using parameters: ${model1.parent.extractParamMap}") // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -73,7 +73,7 @@ object EstimatorTransformerParamExample { // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training, paramMapCombined) - println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + println(s"Model 2 was fit using parameters: ${model2.parent.extractParamMap}") // Prepare test data. 
val test = spark.createDataFrame(Seq( diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 3656773c8b81..ef78c0a1145e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -86,10 +86,10 @@ object GradientBoostedTreeClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${1.0 - accuracy}") val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] - println("Learned classification GBT model:\n" + gbtModel.toDebugString) + println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala index e53aab7f326d..3feb2343f6a8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -73,10 +73,10 @@ object GradientBoostedTreeRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] - println("Learned regression GBT model:\n" + gbtModel.toDebugString) + println(s"Learned regression GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala index 1740a0d3f9d1..0368dcba460b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.ml // $example on$ -import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression} +import org.apache.spark.ml.classification.LogisticRegression // $example off$ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.max @@ -47,25 +47,20 @@ object LogisticRegressionSummaryExample { // $example on$ // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier // example - val trainingSummary = lrModel.summary + val trainingSummary = lrModel.binarySummary // Obtain the objective per iteration. val objectiveHistory = trainingSummary.objectiveHistory println("objectiveHistory:") objectiveHistory.foreach(loss => println(loss)) - // Obtain the metrics useful to judge performance on test data. - // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a - // binary classification problem. - val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] - // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. 
- val roc = binarySummary.roc + val roc = trainingSummary.roc roc.show() - println(s"areaUnderROC: ${binarySummary.areaUnderROC}") + println(s"areaUnderROC: ${trainingSummary.areaUnderROC}") // Set the model threshold to maximize F-Measure - val fMeasure = binarySummary.fMeasureByThreshold + val fMeasure = trainingSummary.fMeasureByThreshold val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure) .select("threshold").head().getDouble(0) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala index 42f0ace7a353..1f7dbddd454e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -48,7 +48,50 @@ object MulticlassLogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for multinomial logistic regression println(s"Coefficients: \n${lrModel.coefficientMatrix}") - println(s"Intercepts: ${lrModel.interceptVector}") + println(s"Intercepts: \n${lrModel.interceptVector}") + + val trainingSummary = lrModel.summary + + // Obtain the objective per iteration + val objectiveHistory = trainingSummary.objectiveHistory + println("objectiveHistory:") + objectiveHistory.foreach(println) + + // for multiclass, we can inspect metrics on a per-label basis + println("False positive rate by label:") + trainingSummary.falsePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) => + println(s"label $label: $rate") + } + + println("True positive rate by label:") + trainingSummary.truePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) => + println(s"label $label: $rate") + } + + println("Precision by label:") + trainingSummary.precisionByLabel.zipWithIndex.foreach { case (prec, label) => + println(s"label $label: $prec") + } + + println("Recall by label:") + trainingSummary.recallByLabel.zipWithIndex.foreach { case (rec, label) => + println(s"label $label: $rec") + } + + + println("F-measure by label:") + trainingSummary.fMeasureByLabel.zipWithIndex.foreach { case (f, label) => + println(s"label $label: $f") + } + + val accuracy = trainingSummary.accuracy + val falsePositiveRate = trainingSummary.weightedFalsePositiveRate + val truePositiveRate = trainingSummary.weightedTruePositiveRate + val fMeasure = trainingSummary.weightedFMeasure + val precision = trainingSummary.weightedPrecision + val recall = trainingSummary.weightedRecall + println(s"Accuracy: $accuracy\nFPR: $falsePositiveRate\nTPR: $truePositiveRate\n" + + s"F-measure: $fMeasure\nPrecision: $precision\nRecall: $recall") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 6fce82d294f8..646f46a92506 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,7 +66,7 @@ object MultilayerPerceptronClassifierExample { val evaluator = new MulticlassClassificationEvaluator() .setMetricName("accuracy") - println("Test set accuracy = " + 
evaluator.evaluate(predictionAndLabels)) + println(s"Test set accuracy = ${evaluator.evaluate(predictionAndLabels)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala index bd9fcc420a66..50c70c626b12 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala @@ -52,7 +52,7 @@ object NaiveBayesExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test set accuracy = " + accuracy) + println(s"Test set accuracy = $accuracy") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala similarity index 65% rename from examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala index 274cc1268f4d..45d816808ed8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala @@ -19,38 +19,34 @@ package org.apache.spark.examples.ml // $example on$ -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +import org.apache.spark.ml.feature.OneHotEncoderEstimator // $example off$ import org.apache.spark.sql.SparkSession -object OneHotEncoderExample { +object OneHotEncoderEstimatorExample { def main(args: Array[String]): Unit = { val spark = SparkSession .builder - .appName("OneHotEncoderExample") + .appName("OneHotEncoderEstimatorExample") .getOrCreate() + // Note: categorical features are usually first encoded with StringIndexer // $example on$ val df = spark.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - )).toDF("id", "category") - - val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) - val indexed = indexer.transform(df) - - val encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec") - - val encoded = encoder.transform(indexed) + (0.0, 1.0), + (1.0, 0.0), + (2.0, 1.0), + (0.0, 2.0), + (0.0, 1.0), + (2.0, 0.0) + )).toDF("categoryIndex1", "categoryIndex2") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("categoryIndex1", "categoryIndex2")) + .setOutputCols(Array("categoryVec1", "categoryVec2")) + val model = encoder.fit(df) + + val encoded = model.transform(df) encoded.show() // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index aedb9e7d3bb7..0fe16fb6dfa9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -36,7 +36,7 @@ object QuantileDiscretizerExample { // Output of QuantileDiscretizer for such small datasets can depend on the number of // partitions. Here we force a single partition to ensure consistent results. 
// Note this is not necessary for normal use cases - .repartition(1) + .repartition(1) // $example on$ val discretizer = new QuantileDiscretizer() @@ -45,7 +45,7 @@ object QuantileDiscretizerExample { .setNumBuckets(3) val result = discretizer.fit(df).transform(df) - result.show() + result.show(false) // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala index 5eafda8ce428..6265f8390252 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -85,10 +85,10 @@ object RandomForestClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] - println("Learned classification forest model:\n" + rfModel.toDebugString) + println(s"Learned classification forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala index 9a0a001c26ef..2679fcb353a8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -72,10 +72,10 @@ object RandomForestRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] - println("Learned regression forest model:\n" + rfModel.toDebugString) + println(s"Learned regression forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala index afa761aee0b9..96bb8ea2338a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -41,8 +41,8 @@ object VectorIndexerExample { val indexerModel = indexer.fit(data) val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet - println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) + println(s"Chose ${categoricalFeatures.size} " + + s"categorical features: ${categoricalFeatures.mkString(", ")}") // Create new column "indexed" with categorical values transformed to indices val indexedData = indexerModel.transform(data) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala new file mode 100644 index 000000000000..688731a791f3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint} +import org.apache.spark.ml.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SparkSession + +object VectorSizeHintExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("VectorSizeHintExample") + .getOrCreate() + + // $example on$ + val dataset = spark.createDataFrame( + Seq( + (0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0), + (0, 18, 1.0, Vectors.dense(0.0, 10.0), 0.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val sizeHint = new VectorSizeHint() + .setInputCol("userFeatures") + .setHandleInvalid("skip") + .setSize(3) + + val datasetWithSize = sizeHint.transform(dataset) + println("Rows where 'userFeatures' is not the right size are filtered out") + datasetWithSize.show(false) + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + // This dataframe can be used by downstream transformers as before + val output = assembler.transform(datasetWithSize) + println("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column 'features'") + output.select("features", "clicked").show(false) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index ff44de56839e..a07535bb5a38 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -42,9 +42,8 @@ object AssociationRulesExample { val results = ar.run(freqItemsets) results.collect().foreach { rule => - println("[" + rule.antecedent.mkString(",") - + "=>" - + rule.consequent.mkString(",") + "]," + rule.confidence) + println(s"[${rule.antecedent.mkString(",")}=>${rule.consequent.mkString(",")} ]" + + s" ${rule.confidence}") } // $example off$ @@ -53,3 +52,4 @@ object AssociationRulesExample { } // scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala index b9263ac6fcff..c6312d71cc91 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -86,7 +86,7 @@ object BinaryClassificationMetricsExample { // AUPRC val auPRC = metrics.areaUnderPR - println("Area 
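Note: the VectorSizeHintExample above uses handleInvalid = "skip", so the row whose userFeatures vector has the wrong length is dropped before assembly. A minimal sketch of the other policies VectorSizeHint accepts (illustrative only, not part of the patch):

import org.apache.spark.ml.feature.VectorSizeHint

// handleInvalid controls what happens to rows whose vector length differs from the hint:
// "error" (the default) fails the transform, "skip" drops the row (as in the example above),
// and "optimistic" accepts the data without checking.
val strictSizeHint = new VectorSizeHint()
  .setInputCol("userFeatures")
  .setSize(3)
  .setHandleInvalid("error")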
under precision-recall curve = " + auPRC) + println(s"Area under precision-recall curve = $auPRC") // Compute thresholds used in ROC and PR curves val thresholds = precision.map(_._1) @@ -96,7 +96,7 @@ object BinaryClassificationMetricsExample { // AUROC val auROC = metrics.areaUnderROC - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala index b50b4592777c..c2f89b72c9a2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -55,8 +55,8 @@ object DecisionTreeClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification tree model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala index 2af45afae3d5..1ecf6426e1f9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -54,8 +54,8 @@ object DecisionTreeRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression tree model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 6435abc12775..f724ee1030f0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -74,7 +74,7 @@ object FPGrowthExample { println(s"Number of frequent itemsets: ${model.freqItemsets.count()}") model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")}, ${itemset.freq}") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala index 00bb3348d2a3..3c56e1941aec 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -54,8 +54,8 @@ object GradientBoostingClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 
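Note: the example changes in this stretch are all the same mechanical refactor, replacing Java-style string concatenation with Scala string interpolation. A condensed before/after sketch with placeholder values (not taken from the patch):

val testErr = 0.12
// before: concatenation
println("Test Error = " + testErr)
// after: interpolation; simple identifiers use $name
println(s"Test Error = $testErr")
// expressions need ${ } braces
println(s"Test Error = ${1.0 - 0.88}")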
!= r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification GBT model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala index d8c263460839..c288bf29bf25 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -53,8 +53,8 @@ object GradientBoostingRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression GBT model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala index 0d391a3637c0..add171973953 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala @@ -68,7 +68,7 @@ object HypothesisTestingExample { // against the label. val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) featureTestResults.zipWithIndex.foreach { case (k, v) => - println("Column " + (v + 1).toString + ":") + println(s"Column ${(v + 1)} :") println(k) } // summary of the test // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 4aee951f5b04..a10d6f0dda88 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -56,7 +56,7 @@ object IsotonicRegressionExample { // Calculate mean squared error between predicted and real labels. 
val meanSquaredError = predictionAndLabel.map { case (p, l) => math.pow((p - l), 2) }.mean() - println("Mean Squared Error = " + meanSquaredError) + println(s"Mean Squared Error = $meanSquaredError") // Save and load model model.save(sc, "target/tmp/myIsotonicRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala index c4d71d862f37..b0a6f1671a89 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala @@ -43,7 +43,7 @@ object KMeansExample { // Evaluate clustering by computing Within Set Sum of Squared Errors val WSSSE = clusters.computeCost(parsedData) - println("Within Set Sum of Squared Errors = " + WSSSE) + println(s"Within Set Sum of Squared Errors = $WSSSE") // Save and load model clusters.save(sc, "target/org/apache/spark/KMeansExample/KMeansModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala index fedcefa09838..123782fa6b9c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -82,7 +82,7 @@ object LBFGSExample { println("Loss of each step in training process") loss.foreach(println) - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala index f2c8ec01439f..d25962c5500e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala @@ -42,11 +42,13 @@ object LatentDirichletAllocationExample { val ldaModel = new LDA().setK(3).run(corpus) // Output topics. 
Each is a distribution over words (matching word count vectors) - println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") + println(s"Learned topics (as distributions over vocab of ${ldaModel.vocabSize} words):") val topics = ldaModel.topicsMatrix for (topic <- Range(0, 3)) { - print("Topic " + topic + ":") - for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + print(s"Topic $topic :") + for (word <- Range(0, ldaModel.vocabSize)) { + print(s"${topics(word, topic)}") + } println() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala index d39961809448..449b725d1d17 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala @@ -52,7 +52,7 @@ object LinearRegressionWithSGDExample { (point.label, prediction) } val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean() - println("training Mean Squared Error = " + MSE) + println(s"training Mean Squared Error $MSE") // Save and load model model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala index eb36697d94ba..eff2393cc3ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala @@ -65,8 +65,8 @@ object PCAExample { val MSE = valuesAndPreds.map { case (v, p) => math.pow((v - p), 2) }.mean() val MSE_pca = valuesAndPreds_pca.map { case (v, p) => math.pow((v - p), 2) }.mean() - println("Mean Squared Error = " + MSE) - println("PCA Mean Squared Error = " + MSE_pca) + println(s"Mean Squared Error = $MSE") + println(s"PCA Mean Squared Error = $MSE_pca") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala index d74d74a37fb1..96deafd469bc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala @@ -41,7 +41,7 @@ object PMMLModelExportExample { val clusters = KMeans.train(parsedData, numClusters, numIterations) // Export to PMML to a String in PMML format - println("PMML Model:\n" + clusters.toPMML) + println(s"PMML Model:\n ${clusters.toPMML}") // Export the model to a local file in PMML format clusters.toPMML("/tmp/kmeans.xml") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala index 69c72c433657..8b789277774a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -42,8 +42,8 @@ object PrefixSpanExample { val model = prefixSpan.run(sequences) model.freqSequences.collect().foreach { freqSequence => println( - freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + - ", " + freqSequence.freq) + s"${freqSequence.sequence.map(_.mkString("[", ", ", 
"]")).mkString("[", ", ", "]")}," + + s" ${freqSequence.freq}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala index f1ebdf1a733e..246e71de2561 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -55,8 +55,8 @@ object RandomForestClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification forest model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala index 11d612e651b4..770e30276bc3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -55,8 +55,8 @@ object RandomForestRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression forest model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala index 6df742d737e7..0bb2b8c8c2b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -56,7 +56,7 @@ object RecommendationExample { val err = (r1 - r2) err * err }.mean() - println("Mean Squared Error = " + MSE) + println(s"Mean Squared Error = $MSE") // Save and load model model.save(sc, "target/tmp/myCollaborativeFilter") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala index b73fe9b2b3fa..285e2ce51263 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala @@ -57,7 +57,7 @@ object SVMWithSGDExample { val metrics = new BinaryClassificationMetrics(scoreAndLabels) val auROC = metrics.areaUnderROC() - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // Save and load model model.save(sc, "target/tmp/scalaSVMWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala index b5c3033bcba0..694c3bb18b04 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -42,15 +42,13 @@ object SimpleFPGrowth { val model = fpg.run(transactions) model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")},${itemset.freq}") } val minConfidence = 0.8 model.generateAssociationRules(minConfidence).collect().foreach { rule => - println( - rule.antecedent.mkString("[", ",", "]") - + " => " + rule.consequent .mkString("[", ",", "]") - + ", " + rule.confidence) + println(s"${rule.antecedent.mkString("[", ",", "]")}=> " + + s"${rule.consequent .mkString("[", ",", "]")},${rule.confidence}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala index 16b074ef6069..3d41bef0af88 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala @@ -41,10 +41,10 @@ object StratifiedSamplingExample { val exactSample = data.sampleByKeyExact(withReplacement = false, fractions = fractions) // $example off$ - println("approxSample size is " + approxSample.collect().size.toString) + println(s"approxSample size is ${approxSample.collect().size}") approxSample.collect().foreach(println) - println("exactSample its size is " + exactSample.collect().size.toString) + println(s"exactSample its size is ${exactSample.collect().size}") exactSample.collect().foreach(println) sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 03bc675299c5..071d341b8161 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -54,7 +54,7 @@ object TallSkinnyPCA { // Compute principal components. val pc = mat.computePrincipalComponents(mat.numCols().toInt) - println("Principal components are:\n" + pc) + println(s"Principal components are:\n $pc") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 067e49b9599e..8ae6de16d80e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -54,7 +54,7 @@ object TallSkinnySVD { // Compute SVD. 
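Note on the StratifiedSamplingExample change above: sampleByKey and sampleByKeyExact both take a per-key fraction map; the exact variant makes additional passes over the data to hit the requested per-stratum sizes, which is why the two collected sizes are printed separately. A small sketch of the inputs they expect, assuming an existing SparkContext named sc (hypothetical data, not from the patch):

// Keyed data and a sampling fraction per key.
val data = sc.parallelize(Seq((1, 'a'), (1, 'b'), (2, 'c'), (2, 'd'), (2, 'e'), (3, 'f')))
val fractions = Map(1 -> 0.1, 2 -> 0.6, 3 -> 0.3)
val approxSample = data.sampleByKey(withReplacement = false, fractions = fractions)
val exactSample = data.sampleByKeyExact(withReplacement = false, fractions = fractions)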
val svd = mat.computeSVD(mat.numCols().toInt) - println("Singular values are " + svd.s) + println(s"Singular values are ${svd.s}") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 43044d01b120..25c7bf287197 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -82,9 +82,9 @@ class CustomReceiver(host: String, port: Int) var socket: Socket = null var userInput: String = null try { - logInfo("Connecting to " + host + ":" + port) + logInfo(s"Connecting to $host : $port") socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) + logInfo(s"Connected to $host : $port") val reader = new BufferedReader( new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)) userInput = reader.readLine() @@ -98,7 +98,7 @@ class CustomReceiver(host: String, port: Int) restart("Trying to connect again") } catch { case e: java.net.ConnectException => - restart("Error connecting to " + host + ":" + port, e) + restart(s"Error connecting to $host : $port", e) case t: Throwable => restart("Error receiving data", t) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index def06026bde9..2082fb71afdf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -18,6 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.streaming +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.serialization.StringDeserializer + import org.apache.spark.SparkConf import org.apache.spark.streaming._ import org.apache.spark.streaming.kafka010._ @@ -26,18 +29,20 @@ import org.apache.spark.streaming.kafka010._ * Consumes messages from one or more topics in Kafka and does wordcount. 
* Usage: DirectKafkaWordCount * is a list of one or more Kafka brokers + * is a consumer group name to consume from topics * is a list of one or more kafka topics to consume from * * Example: * $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port \ - * topic1,topic2 + * consumer-group topic1,topic2 */ object DirectKafkaWordCount { def main(args: Array[String]) { - if (args.length < 2) { + if (args.length < 3) { System.err.println(s""" |Usage: DirectKafkaWordCount | is a list of one or more Kafka brokers + | is a consumer group name to consume from topics | is a list of one or more kafka topics to consume from | """.stripMargin) @@ -46,7 +51,7 @@ object DirectKafkaWordCount { StreamingExamples.setStreamingLogLevels() - val Array(brokers, topics) = args + val Array(brokers, groupId, topics) = args // Create context with 2 second batch interval val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") @@ -54,7 +59,11 @@ object DirectKafkaWordCount { // Create direct kafka stream with brokers and topics val topicsSet = topics.split(",").toSet - val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers) + val kafkaParams = Map[String, Object]( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG -> brokers, + ConsumerConfig.GROUP_ID_CONFIG -> groupId, + ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer], + ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer]) val messages = KafkaUtils.createDirectStream[String, String]( ssc, LocationStrategies.PreferConsistent, diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index 5322929d177b..437ccf0898d7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -54,7 +54,7 @@ object RawNetworkGrep { ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = ssc.union(rawStreams) union.filter(_.contains("the")).count().foreachRDD(r => - println("Grep count: " + r.collect().mkString)) + println(s"Grep count: ${r.collect().mkString}")) ssc.start() ssc.awaitTermination() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 49c042732113..f018f3a26d2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -130,10 +130,10 @@ object RecoverableNetworkWordCount { true } }.collect().mkString("[", ", ", "]") - val output = "Counts at time " + time + " " + counts + val output = s"Counts at time $time $counts" println(output) - println("Dropped " + droppedWordsCounter.value + " word(s) totally") - println("Appending to " + outputFile.getAbsolutePath) + println(s"Dropped ${droppedWordsCounter.value} word(s) totally") + println(s"Appending to ${outputFile.getAbsolutePath}") Files.append(output + "\n", outputFile, Charset.defaultCharset()) } ssc @@ -141,7 +141,7 @@ object RecoverableNetworkWordCount { def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Your arguments were " + args.mkString("[", ", ", "]")) + System.err.println(s"Your arguments were 
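Note: the DirectKafkaWordCount change above switches the parameter map to what the kafka-0-10 integration requires, namely an explicit group id plus key/value deserializers. For reference, a sketch of how such a map is typically handed to the direct stream, reusing the ssc, topicsSet and kafkaParams names from the example and the kafka010 imports already present in that file:

val messages = KafkaUtils.createDirectStream[String, String](
  ssc,
  LocationStrategies.PreferConsistent,
  ConsumerStrategies.Subscribe[String, String](topicsSet, kafkaParams))
// Each element is a ConsumerRecord; the payload is exposed via value().
val lines = messages.map(_.value)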
${args.mkString("[", ", ", "]")}") System.err.println( """ |Usage: RecoverableNetworkWordCount diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 0ddd065f0db2..2108bc63edea 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -90,13 +90,13 @@ object PageViewGenerator { val viewsPerSecond = args(1).toFloat val sleepDelayMs = (1000.0 / viewsPerSecond).toInt val listener = new ServerSocket(port) - println("Listening on port: " + port) + println(s"Listening on port: $port") while (true) { val socket = listener.accept() new Thread() { override def run(): Unit = { - println("Got client connected from: " + socket.getInetAddress) + println(s"Got client connected from: ${socket.getInetAddress}") val out = new PrintWriter(socket.getOutputStream(), true) while (true) { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 1ba093f57b32..b8e7c7e9e915 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -104,8 +104,8 @@ object PageViewStream { .foreachRDD((rdd, time) => rdd.join(userList) .map(_._2._2) .take(10) - .foreach(u => println("Saw user %s at time %s".format(u, time)))) - case _ => println("Invalid metric entered: " + metric) + .foreach(u => println(s"Saw user $u at time $time"))) + case _ => println(s"Invalid metric entered: $metric") } ssc.start() diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 485b562dce99..431339d41219 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/docker/spark-test/base/Dockerfile b/external/docker/spark-test/base/Dockerfile index 5a95a9387c31..c70cd7136767 100644 --- a/external/docker/spark-test/base/Dockerfile +++ b/external/docker/spark-test/base/Dockerfile @@ -15,14 +15,14 @@ # limitations under the License. 
# -FROM ubuntu:precise +FROM ubuntu:xenial # Upgrade package index -# install a few other useful packages plus Open Jdk 7 +# install a few other useful packages plus Open Jdk 8 # Remove unneeded /var/lib/apt/lists/* after install to reduce the # docker image size (by ~30MB) RUN apt-get update && \ - apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ + apt-get install -y less openjdk-8-jre-headless iproute2 vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* ENV SCALA_VERSION 2.11.8 diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 71016bc645ca..7cd1ec4c9c09 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 12630840e79d..f810aa80e878 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 87a09642405a..498e88f665eb 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index d6f97316b326..a742b8d6dbdd 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 0c9f0aa765a3..16bbc6db641c 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala new file mode 100644 index 000000000000..f26c134c2f6e --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.concurrent.TimeoutException + +import org.apache.kafka.clients.consumer.{ConsumerRecord, OffsetOutOfRangeException} +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType + +/** + * A [[ContinuousReader]] for data from kafka. + * + * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be + * read by per-task consumers generated later. + * @param kafkaParams String params for per-task Kafka consumers. + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceOptions]] params which + * are not Kafka consumer params. + * @param metadataPath Path to a directory this reader can use for writing metadata. + * @param initialOffsets The Kafka offsets to start reading data at. + * @param failOnDataLoss Flag indicating whether reading should fail in data loss + * scenarios, where some offsets after the specified initial ones can't be + * properly read. + */ +class KafkaContinuousReader( + offsetReader: KafkaOffsetReader, + kafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + initialOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends ContinuousReader with SupportsScanUnsafeRow with Logging { + + private lazy val session = SparkSession.getActiveSession.get + private lazy val sc = session.sparkContext + + private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong + + // Initialized when creating reader factories. If this diverges from the partitions at the latest + // offsets, we need to reconfigure. + // Exposed outside this object only for unit tests. 
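Note: the KafkaContinuousReader introduced in this file backs Spark's continuous processing mode for the Kafka source. From the user's side it is selected simply by starting the streaming query with a continuous trigger; a minimal sketch, assuming an already-built streaming DataFrame named df (not part of the patch):

import org.apache.spark.sql.streaming.Trigger

val query = df.writeStream
  .format("console")
  .trigger(Trigger.Continuous("1 second"))  // checkpoint interval for continuous execution
  .start()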
+ @volatile private[sql] var knownPartitions: Set[TopicPartition] = _ + + override def readSchema: StructType = KafkaOffsetReader.kafkaSchema + + private var offset: Offset = _ + override def setStartOffset(start: ju.Optional[Offset]): Unit = { + offset = start.orElse { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + logInfo(s"Initial offsets: $offsets") + offsets + } + } + + override def getStartOffset(): Offset = offset + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + import scala.collection.JavaConverters._ + + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + knownPartitions = startOffsets.keySet + + startOffsets.toSeq.map { + case (topicPartition, start) => + KafkaContinuousDataReaderFactory( + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) + .asInstanceOf[DataReaderFactory[UnsafeRow]] + }.asJava + } + + /** Stop this source and free any resources it has allocated. */ + def stop(): Unit = synchronized { + offsetReader.close() + } + + override def commit(end: Offset): Unit = {} + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + val mergedMap = offsets.map { + case KafkaSourcePartitionOffset(p, o) => Map(p -> o) + }.reduce(_ ++ _) + KafkaSourceOffset(mergedMap) + } + + override def needsReconfiguration(): Boolean = { + knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + } + + override def toString(): String = s"KafkaSource[$offsetReader]" + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } +} + +/** + * A data reader factory for continuous Kafka processing. This will be serialized and transformed + * into a full reader on executors. + * + * @param topicPartition The (topic, partition) pair this task is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param pollTimeoutMs The timeout for Kafka consumer polling. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. 
+ */ +case class KafkaContinuousDataReaderFactory( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] { + + override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = { + val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] + require(kafkaOffset.topicPartition == topicPartition, + s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") + new KafkaContinuousDataReader( + topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + } + + override def createDataReader(): KafkaContinuousDataReader = { + new KafkaContinuousDataReader( + topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + } +} + +/** + * A per-task data reader for continuous Kafka processing. + * + * @param topicPartition The (topic, partition) pair this data reader is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param pollTimeoutMs The timeout for Kafka consumer polling. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +class KafkaContinuousDataReader( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) + private val converter = new KafkaRecordToUnsafeRowConverter + + private var nextKafkaOffset = startOffset + private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ + + override def next(): Boolean = { + var r: ConsumerRecord[Array[Byte], Array[Byte]] = null + while (r == null) { + if (TaskContext.get().isInterrupted() || TaskContext.get().isCompleted()) return false + // Our consumer.get is not interruptible, so we have to set a low poll timeout, leaving + // interrupt points to end the query rather than waiting for new data that might never come. + try { + r = consumer.get( + nextKafkaOffset, + untilOffset = Long.MaxValue, + pollTimeoutMs, + failOnDataLoss) + } catch { + // We didn't read within the timeout. We're supposed to block indefinitely for new data, so + // swallow and ignore this. + case _: TimeoutException => + + // This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range, + // or if it's the endpoint of the data range (i.e. the "true" next offset). 
+ case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] => + val range = consumer.getAvailableOffsetRange() + if (range.latest >= nextKafkaOffset && range.earliest <= nextKafkaOffset) { + // retry + } else { + throw e + } + } + } + nextKafkaOffset = r.offset + 1 + currentRecord = r + true + } + + override def get(): UnsafeRow = { + converter.toUnsafeRow(currentRecord) + } + + override def getOffset(): KafkaSourcePartitionOffset = { + KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) + } + + override def close(): Unit = { + consumer.release() + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala similarity index 65% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 90ed7b1fba2f..48508d057a54 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -27,30 +27,73 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.kafka010.KafkaDataConsumer.AvailableOffsetRange +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.util.UninterruptibleThread +private[kafka010] sealed trait KafkaDataConsumer { + /** + * Get the record for the given offset if available. Otherwise it will either throw error + * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), + * or null. + * + * @param offset the offset to fetch. + * @param untilOffset the max offset to fetch. Exclusive. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at + * offset if available, or throw exception.when `failOnDataLoss` is `false`, + * this method will either return record at offset if available, or return + * the next earliest available record less than untilOffset, or null. It + * will not throw any exception. + */ + def get( + offset: Long, + untilOffset: Long, + pollTimeoutMs: Long, + failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { + internalConsumer.get(offset, untilOffset, pollTimeoutMs, failOnDataLoss) + } + + /** + * Return the available offset range of the current partition. It's a pair of the earliest offset + * and the latest offset. + */ + def getAvailableOffsetRange(): AvailableOffsetRange = internalConsumer.getAvailableOffsetRange() + + /** + * Release this consumer from being further used. Depending on its implementation, + * this consumer will be either finalized, or reset for reuse later. + */ + def release(): Unit + + /** Reference to the internal implementation that this wrapper delegates to */ + protected def internalConsumer: InternalKafkaConsumer +} + /** - * Consumer of single topicpartition, intended for cached reuse. - * Underlying consumer is not threadsafe, so neither is this, - * but processing the same topicpartition and group id in multiple threads is usually bad anyway. 
+ * A wrapper around Kafka's KafkaConsumer that throws error when data loss is detected. + * This is not for direct use outside this file. */ -private[kafka010] case class CachedKafkaConsumer private( +private[kafka010] case class InternalKafkaConsumer( topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) extends Logging { - import CachedKafkaConsumer._ + import InternalKafkaConsumer._ private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - private var consumer = createConsumer + @volatile private var consumer = createConsumer /** indicates whether this consumer is in use or not */ - private var inuse = true + @volatile var inUse = true + + /** indicate whether this consumer is going to be stopped in the next release */ + @volatile var markedForClose = false /** Iterator to the already fetch data */ - private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] - private var nextOffsetInFetchedData = UNKNOWN_OFFSET + @volatile private var fetchedData = + ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + @volatile private var nextOffsetInFetchedData = UNKNOWN_OFFSET /** Create a KafkaConsumer to fetch records for `topicPartition` */ private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { @@ -61,8 +104,6 @@ private[kafka010] case class CachedKafkaConsumer private( c } - case class AvailableOffsetRange(earliest: Long, latest: Long) - private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match { case ut: UninterruptibleThread => ut.runUninterruptibly(body) @@ -313,21 +354,51 @@ private[kafka010] case class CachedKafkaConsumer private( } } -private[kafka010] object CachedKafkaConsumer extends Logging { - private val UNKNOWN_OFFSET = -2L +private[kafka010] object KafkaDataConsumer extends Logging { + + case class AvailableOffsetRange(earliest: Long, latest: Long) + + private case class CachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) + extends KafkaDataConsumer { + assert(internalConsumer.inUse) // make sure this has been set to true + override def release(): Unit = { KafkaDataConsumer.release(internalConsumer) } + } + + private case class NonCachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) + extends KafkaDataConsumer { + override def release(): Unit = { internalConsumer.close() } + } - private case class CacheKey(groupId: String, topicPartition: TopicPartition) + private case class CacheKey(groupId: String, topicPartition: TopicPartition) { + def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) = + this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition) + } + // This cache has the following important properties. + // - We make a best-effort attempt to maintain the max size of the cache as configured capacity. + // The capacity is not guaranteed to be maintained, especially when there are more active + // tasks simultaneously using consumers than the capacity. 
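Note: the consumer cache described in the comment above relies on a standard java.util.LinkedHashMap idiom: constructing the map with accessOrder = true makes iteration order follow recency of access, and overriding removeEldestEntry lets each put decide whether to evict. A stripped-down illustration of that idiom (not part of the patch):

import java.{util => ju}

val capacity = 64
val lru = new ju.LinkedHashMap[String, Long](capacity, 0.75f, true) {
  // Called after every put; returning true evicts the least-recently-accessed entry.
  override def removeEldestEntry(eldest: ju.Map.Entry[String, Long]): Boolean =
    size() > capacity
}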
private lazy val cache = { val conf = SparkEnv.get.conf val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64) - new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) { + new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) { override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = { - if (entry.getValue.inuse == false && this.size > capacity) { - logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " + - s"removing consumer for ${entry.getKey}") + entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer]): Boolean = { + + // Try to remove the least-used entry if its currently not in use. + // + // If you cannot remove it, then the cache will keep growing. In the worst case, + // the cache will grow to the max number of concurrent tasks that can run in the executor, + // (that is, number of tasks slots) after which it will never reduce. This is unlikely to + // be a serious problem because an executor with more than 64 (default) tasks slots is + // likely running on a beefy machine that can handle a large number of simultaneously + // active consumers. + + if (entry.getValue.inUse == false && this.size > capacity) { + logWarning( + s"KafkaConsumer cache hitting max capacity of $capacity, " + + s"removing consumer for ${entry.getKey}") try { entry.getValue.close() } catch { @@ -342,80 +413,87 @@ private[kafka010] object CachedKafkaConsumer extends Logging { } } - def releaseKafkaConsumer( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): Unit = { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) - - synchronized { - val consumer = cache.get(key) - if (consumer != null) { - consumer.inuse = false - } else { - logWarning(s"Attempting to release consumer that does not exist") - } - } - } - /** - * Removes (and closes) the Kafka Consumer for the given topic, partition and group id. + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. + * + * Note: This method guarantees that the consumer returned is not currently in use by any one + * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by + * caching them and tracking when they are in use. */ - def removeKafkaConsumer( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): Unit = { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) + def acquire( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + useCache: Boolean): KafkaDataConsumer = synchronized { + val key = new CacheKey(topicPartition, kafkaParams) + val existingInternalConsumer = cache.get(key) - synchronized { - val removedConsumer = cache.remove(key) - if (removedConsumer != null) { - removedConsumer.close() + lazy val newInternalConsumer = new InternalKafkaConsumer(topicPartition, kafkaParams) + + if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { + // If this is reattempt at running the task, then invalidate cached consumer if any and + // start with a new one. 
+ if (existingInternalConsumer != null) { + // Consumer exists in cache. If its in use, mark it for closing later, or close it now. + if (existingInternalConsumer.inUse) { + existingInternalConsumer.markedForClose = true + } else { + existingInternalConsumer.close() + } } + cache.remove(key) // Invalidate the cache in any case + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else if (!useCache) { + // If planner asks to not reuse consumers, then do not use it, return a new consumer + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else if (existingInternalConsumer == null) { + // If consumer is not already cached, then put a new in the cache and return it + cache.put(key, newInternalConsumer) + newInternalConsumer.inUse = true + CachedKafkaDataConsumer(newInternalConsumer) + + } else if (existingInternalConsumer.inUse) { + // If consumer is already cached but is currently in use, then return a new consumer + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else { + // If consumer is already cached and is currently not in use, then return that consumer + existingInternalConsumer.inUse = true + CachedKafkaDataConsumer(existingInternalConsumer) } } - /** - * Get a cached consumer for groupId, assigned to topic and partition. - * If matching consumer doesn't already exist, will be created using kafkaParams. - */ - def getOrCreate( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) - - // If this is reattempt at running the task, then invalidate cache and start with - // a new consumer - if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { - removeKafkaConsumer(topic, partition, kafkaParams) - val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams) - consumer.inuse = true - cache.put(key, consumer) - consumer - } else { - if (!cache.containsKey(key)) { - cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams)) + private def release(intConsumer: InternalKafkaConsumer): Unit = { + synchronized { + + // Clear the consumer from the cache if this is indeed the consumer present in the cache + val key = new CacheKey(intConsumer.topicPartition, intConsumer.kafkaParams) + val cachedIntConsumer = cache.get(key) + if (intConsumer.eq(cachedIntConsumer)) { + // The released consumer is the same object as the cached one. + if (intConsumer.markedForClose) { + intConsumer.close() + cache.remove(key) + } else { + intConsumer.inUse = false + } + } else { + // The released consumer is either not the same one as in the cache, or not in the cache + // at all. This may happen if the cache was invalidate while this consumer was being used. + // Just close this consumer. + intConsumer.close() + logInfo(s"Released a supposedly cached consumer that was not found in the cache") } - val consumer = cache.get(key) - consumer.inuse = true - consumer } } +} - /** Create an [[CachedKafkaConsumer]] but don't put it into cache. 
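Note: taken together, the new acquire/release API replaces the old getOrCreate/releaseKafkaConsumer/removeKafkaConsumer trio. The expected calling pattern on an executor looks roughly like this, using the signatures introduced in this file; topicPartition and kafkaParams are assumed to be in scope and the literal values are placeholders:

val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = true)
try {
  val record = consumer.get(
    offset = 100L,
    untilOffset = Long.MaxValue,
    pollTimeoutMs = 512L,
    failOnDataLoss = false)
  // ... process the ConsumerRecord (may be null if the offset range was exhausted) ...
} finally {
  consumer.release()  // returns a cached consumer to the pool, or closes a non-cached one
}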
*/ - def createUncached( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = { - new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams) - } +private[kafka010] object InternalKafkaConsumer extends Logging { + + private val UNKNOWN_OFFSET = -2L private def reportDataLoss0( failOnDataLoss: Boolean, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala new file mode 100644 index 000000000000..cbe655f9bff1 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.io._ +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.UninterruptibleThread + +/** + * A [[MicroBatchReader]] that reads data from Kafka. + * + * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains + * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For + * example if the last record in a Kafka topic "t", partition 2 is offset 5, then + * KafkaSourceOffset will contain TopicPartition("t", 2) -> 6. This is done keep it consistent + * with the semantics of `KafkaConsumer.position()`. + * + * Zero data lost is not guaranteed when topics are deleted. If zero data lost is critical, the user + * must make sure all messages in a topic have been processed when deleting a topic. + * + * There is a known issue caused by KAFKA-1894: the query using Kafka maybe cannot be stopped. + * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers + * and not use wrong broker addresses. 
+ */ +private[kafka010] class KafkaMicroBatchReader( + kafkaOffsetReader: KafkaOffsetReader, + executorKafkaParams: ju.Map[String, Object], + options: DataSourceOptions, + metadataPath: String, + startingOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends MicroBatchReader with SupportsScanUnsafeRow with Logging { + + private var startPartitionOffsets: PartitionOffsetMap = _ + private var endPartitionOffsets: PartitionOffsetMap = _ + + private val pollTimeoutMs = options.getLong( + "kafkaConsumer.pollTimeoutMs", + SparkEnv.get.conf.getTimeAsMs("spark.network.timeout", "120s")) + + private val maxOffsetsPerTrigger = + Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) + + private val rangeCalculator = KafkaOffsetRangeCalculator(options) + /** + * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only + * called in StreamExecutionThread. Otherwise, interrupting a thread while running + * `KafkaConsumer.poll` may hang forever (KAFKA-1894). + */ + private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets() + + override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + startPartitionOffsets = Option(start.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse(initialPartitionOffsets) + + endPartitionOffsets = Option(end.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse { + val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() + maxOffsetsPerTrigger.map { maxOffsets => + rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) + }.getOrElse { + latestPartitionOffsets + } + } + } + + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + // Find the new partitions, and get their earliest offsets + val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) + val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) + if (newPartitionInitialOffsets.keySet != newPartitions) { + // We cannot get from offsets for some partitions. It means they got deleted. + val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet) + reportDataLoss( + s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") + } + logInfo(s"Partitions added: $newPartitionInitialOffsets") + newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) => + reportDataLoss( + s"Added partition $p starts from $o instead of 0. Some data may have been missed") + } + + // Find deleted partitions, and report data loss if required + val deletedPartitions = startPartitionOffsets.keySet.diff(endPartitionOffsets.keySet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") + } + + // Use the end partitions to calculate offset ranges to ignore partitions that have + // been deleted + val topicPartitions = endPartitionOffsets.keySet.filter { tp => + // Ignore partitions that we don't know the from offsets. 
+ newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp) + }.toSeq + logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) + + // Calculate offset ranges + val offsetRanges = rangeCalculator.getRanges( + fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets, + untilOffsets = endPartitionOffsets, + executorLocations = getSortedExecutorList()) + + // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions, + // that is, concurrent tasks will not read the same TopicPartitions. + val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size + + // Generate factories based on the offset ranges + val factories = offsetRanges.map { range => + new KafkaMicroBatchDataReaderFactory( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) + } + factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava + } + + override def getStartOffset: Offset = { + KafkaSourceOffset(startPartitionOffsets) + } + + override def getEndOffset: Offset = { + KafkaSourceOffset(endPartitionOffsets) + } + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = { + kafkaOffsetReader.close() + } + + override def toString(): String = s"KafkaV2[$kafkaOffsetReader]" + + /** + * Read initial partition offsets from the checkpoint, or decide the offsets and write them to + * the checkpoint. + */ + private def getOrCreateInitialPartitionOffsets(): PartitionOffsetMap = { + // Make sure that `KafkaConsumer.poll` is only called in StreamExecutionThread. + // Otherwise, interrupting a thread while running `KafkaConsumer.poll` may hang forever + // (KAFKA-1894). 
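deserializeOffset above delegates to JsonUtils.partitionOffsets, whose layout is a JSON object of topic to partition to offset, as the checkpoint fixtures added later in this diff also show. A short sketch, with illustrative topic names and offsets:

// Illustrative offsets only; the layout matches the checkpoint fixtures added later in this
// diff: a JSON object mapping topic -> partition -> offset.
val json = """{"topic-a":{"0":23,"1":300},"topic-b":{"0":42}}"""
val offset = KafkaSourceOffset(JsonUtils.partitionOffsets(json))
// offset.partitionToOffsets contains entries such as TopicPartition("topic-a", 0) -> 23L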
+ assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + + // SparkSession is required for getting Hadoop configuration for writing to checkpoints + assert(SparkSession.getActiveSession.nonEmpty) + + val metadataLog = + new KafkaSourceInitialOffsetWriter(SparkSession.getActiveSession.get, metadataPath) + metadataLog.get(0).getOrElse { + val offsets = startingOffsets match { + case EarliestOffsetRangeLimit => + KafkaSourceOffset(kafkaOffsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => + KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => + kafkaOffsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + metadataLog.add(0, offsets) + logInfo(s"Initial offsets: $offsets") + offsets + }.partitionToOffsets + } + + /** Proportionally distribute limit number of offsets among topicpartitions */ + private def rateLimit( + limit: Long, + from: PartitionOffsetMap, + until: PartitionOffsetMap): PartitionOffsetMap = { + val fromNew = kafkaOffsetReader.fetchEarliestOffsets(until.keySet.diff(from.keySet).toSeq) + val sizes = until.flatMap { + case (tp, end) => + // If begin isn't defined, something's wrong, but let alert logic in getBatch handle it + from.get(tp).orElse(fromNew.get(tp)).flatMap { begin => + val size = end - begin + logDebug(s"rateLimit $tp size is $size") + if (size > 0) Some(tp -> size) else None + } + } + val total = sizes.values.sum.toDouble + if (total < 1) { + until + } else { + until.map { + case (tp, end) => + tp -> sizes.get(tp).map { size => + val begin = from.get(tp).getOrElse(fromNew(tp)) + val prorate = limit * (size / total) + // Don't completely starve small topicpartitions + val off = begin + (if (prorate < 1) Math.ceil(prorate) else Math.floor(prorate)).toLong + // Paranoia, make sure not to return an offset that's past end + Math.min(end, off) + }.getOrElse(end) + } + } + } + + private def getSortedExecutorList(): Array[String] = { + + def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = { + if (a.host == b.host) { + a.executorId > b.executorId + } else { + a.host > b.host + } + } + + val bm = SparkEnv.get.blockManager + bm.master.getPeers(bm.blockManagerId).toArray + .map(x => ExecutorCacheTaskLocation(x.host, x.executorId)) + .sortWith(compare) + .map(_.toString) + } + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } + + /** A version of [[HDFSMetadataLog]] specialized for saving the initial offsets. 
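The proration in rateLimit above can be easier to follow as a standalone function over plain maps. The sketch below drops the fetchEarliestOffsets fallback for newly added partitions and uses string keys instead of TopicPartition, so it is only an approximation of the real method:

// Illustrative re-statement of the proration in rateLimit (not the real method).
def prorate(limit: Long, from: Map[String, Long], until: Map[String, Long]): Map[String, Long] = {
  // Per-partition backlog sizes, skipping partitions with nothing new to read.
  val sizes = until.flatMap { case (tp, end) =>
    from.get(tp).map(begin => tp -> (end - begin)).filter(_._2 > 0)
  }
  val total = sizes.values.sum.toDouble
  if (total < 1) until
  else until.map { case (tp, end) =>
    tp -> sizes.get(tp).map { size =>
      val begin = from(tp)
      val share = limit * (size / total)
      // Small partitions still get at least one offset; large ones are rounded down.
      val off = begin + (if (share < 1) math.ceil(share) else math.floor(share)).toLong
      math.min(end, off)                        // never run past the available end offset
    }.getOrElse(end)
  }
}

// e.g. 30 pending offsets on p0 and 10 on p1, with maxOffsetsPerTrigger = 20:
// p0 advances by floor(20 * 30/40) = 15 offsets and p1 by floor(20 * 10/40) = 5.
// prorate(20, Map("p0" -> 0L, "p1" -> 0L), Map("p0" -> 30L, "p1" -> 10L))
//   => Map("p0" -> 15, "p1" -> 5)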
*/ + class KafkaSourceInitialOffsetWriter(sparkSession: SparkSession, metadataPath: String) + extends HDFSMetadataLog[KafkaSourceOffset](sparkSession, metadataPath) { + + val VERSION = 1 + + override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { + out.write(0) // A zero byte is written to support Spark 2.1.0 (SPARK-19517) + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): KafkaSourceOffset = { + in.read() // A zero byte is read to support Spark 2.1.0 (SPARK-19517) + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + KafkaSourceOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + // The log was generated by Spark 2.1.0 + KafkaSourceOffset(SerializedOffset(content)) + } + } + } +} + +/** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] case class KafkaMicroBatchDataReaderFactory( + offsetRange: KafkaOffsetRange, + executorKafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean, + reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] { + + override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray + + override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( + offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) +} + +/** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. 
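The serialize/deserialize pair above defines the on-disk layout of the initial-offset checkpoint: a single zero byte for Spark 2.1.0 compatibility, a version line, then the offset JSON. A small round-trip sketch with an illustrative payload:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStreamReader}
import java.nio.charset.StandardCharsets
import org.apache.commons.io.IOUtils

// Sketch of the layout produced by KafkaSourceInitialOffsetWriter.serialize above:
//   one 0x00 byte (SPARK-19517 compatibility) | "v1\n" | offset JSON
val out = new ByteArrayOutputStream()
out.write(0)
out.write("v1\n".getBytes(StandardCharsets.UTF_8))
out.write("""{"t":{"0":5}}""".getBytes(StandardCharsets.UTF_8))

val in = new ByteArrayInputStream(out.toByteArray)
in.read()                                                  // skip the leading zero byte
val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
assert(content == "v1\n" + """{"t":{"0":5}}""")

A file without the leading "v" marker is treated as a Spark 2.1.0 checkpoint, which is why the version check branches on the first character of the remaining content.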
*/ +private[kafka010] case class KafkaMicroBatchDataReader( + offsetRange: KafkaOffsetRange, + executorKafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean, + reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { + + private val consumer = KafkaDataConsumer.acquire( + offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) + + private val rangeToRead = resolveRange(offsetRange) + private val converter = new KafkaRecordToUnsafeRowConverter + + private var nextOffset = rangeToRead.fromOffset + private var nextRow: UnsafeRow = _ + + override def next(): Boolean = { + if (nextOffset < rangeToRead.untilOffset) { + val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss) + if (record != null) { + nextRow = converter.toUnsafeRow(record) + true + } else { + false + } + } else { + false + } + } + + override def get(): UnsafeRow = { + assert(nextRow != null) + nextOffset += 1 + nextRow + } + + override def close(): Unit = { + consumer.release() + } + + private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { + if (range.fromOffset < 0 || range.untilOffset < 0) { + // Late bind the offset range + val availableOffsetRange = consumer.getAvailableOffsetRange() + val fromOffset = if (range.fromOffset < 0) { + assert(range.fromOffset == KafkaOffsetRangeLimit.EARLIEST, + s"earliest offset ${range.fromOffset} does not equal ${KafkaOffsetRangeLimit.EARLIEST}") + availableOffsetRange.earliest + } else { + range.fromOffset + } + val untilOffset = if (range.untilOffset < 0) { + assert(range.untilOffset == KafkaOffsetRangeLimit.LATEST, + s"latest offset ${range.untilOffset} does not equal ${KafkaOffsetRangeLimit.LATEST}") + availableOffsetRange.latest + } else { + range.untilOffset + } + KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset, None) + } else { + range + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala new file mode 100644 index 000000000000..6631ae84167c --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.sql.sources.v2.DataSourceOptions + + +/** + * Class to calculate offset ranges to process based on the the from and until offsets, and + * the configured `minPartitions`. 
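For context, a DataReader such as the one above is driven by the engine through the usual next()/get()/close() loop; the sketch below shows only that generic calling pattern, not Spark's actual executor code:

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.sources.v2.reader.DataReader

// Generic consumption pattern for a DataReader[UnsafeRow]; `reader` stands in for an
// instance produced by the factory's createDataReader().
def drain(reader: DataReader[UnsafeRow]): Long = {
  var count = 0L
  try {
    while (reader.next()) {   // false once the offset range is exhausted
      reader.get()            // only valid after next() has returned true
      count += 1
    }
  } finally {
    reader.close()            // releases (or returns to the cache) the Kafka consumer
  }
  count
}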
+ */ +private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) { + require(minPartitions.isEmpty || minPartitions.get > 0) + + import KafkaOffsetRangeCalculator._ + /** + * Calculate the offset ranges that we are going to process this batch. If `minPartitions` + * is not set or is less than or equal to the number of `topicPartitions` that we're going to + * consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions. If + * `minPartitions` is set higher than the number of our `topicPartitions`, then we will split up + * the read tasks of the skewed partitions to multiple Spark tasks. + * The number of Spark tasks will be *approximately* `minPartitions`. It can be less or more + * depending on rounding errors or Kafka partitions that didn't receive any new data. + */ + def getRanges( + fromOffsets: PartitionOffsetMap, + untilOffsets: PartitionOffsetMap, + executorLocations: Seq[String] = Seq.empty): Seq[KafkaOffsetRange] = { + val partitionsToRead = untilOffsets.keySet.intersect(fromOffsets.keySet) + + val offsetRanges = partitionsToRead.toSeq.map { tp => + KafkaOffsetRange(tp, fromOffsets(tp), untilOffsets(tp), preferredLoc = None) + }.filter(_.size > 0) + + // If minPartitions is not set or there are already enough partitions to satisfy it + if (minPartitions.isEmpty || offsetRanges.size > minPartitions.get) { + // Assign preferred executor locations to each range such that the same topic-partition is + // preferentially read from the same executor and the KafkaConsumer can be reused. + offsetRanges.map { range => + range.copy(preferredLoc = getLocation(range.topicPartition, executorLocations)) + } + } else { + + // Split offset ranges with a relatively large amount of data into smaller ones. + val totalSize = offsetRanges.map(_.size).sum + val idealRangeSize = totalSize.toDouble / minPartitions.get + + offsetRanges.flatMap { range => + // Split the current range into subranges as close to the ideal range size as possible + val numSplitsInRange = math.round(range.size.toDouble / idealRangeSize).toInt + + (0 until numSplitsInRange).map { i => + val splitStart = range.fromOffset + range.size * (i.toDouble / numSplitsInRange) + val splitEnd = range.fromOffset + range.size * ((i.toDouble + 1) / numSplitsInRange) + KafkaOffsetRange( + range.topicPartition, splitStart.toLong, splitEnd.toLong, preferredLoc = None) + } + } + } + } + + private def getLocation(tp: TopicPartition, executorLocations: Seq[String]): Option[String] = { + def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b + + val numExecutors = executorLocations.length + if (numExecutors > 0) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. 
+ Some(executorLocations(floorMod(tp.hashCode, numExecutors))) + } else None + } +} + +private[kafka010] object KafkaOffsetRangeCalculator { + + def apply(options: DataSourceOptions): KafkaOffsetRangeCalculator = { + val optionalValue = Option(options.get("minPartitions").orElse(null)).map(_.toInt) + new KafkaOffsetRangeCalculator(optionalValue) + } +} + +private[kafka010] case class KafkaOffsetRange( + topicPartition: TopicPartition, + fromOffset: Long, + untilOffset: Long, + preferredLoc: Option[String]) { + lazy val size: Long = untilOffset - fromOffset +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 3e65949a6fd1..82066697cb95 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -75,7 +75,17 @@ private[kafka010] class KafkaOffsetReader( * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the * offsets and never commits them. */ - protected var consumer = createConsumer() + @volatile protected var _consumer: Consumer[Array[Byte], Array[Byte]] = null + + protected def consumer: Consumer[Array[Byte], Array[Byte]] = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer == null) { + val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) + newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) + _consumer = consumerStrategy.createConsumer(newKafkaParams) + } + _consumer + } private val maxOffsetFetchAttempts = readerOptions.getOrElse("fetchOffset.numRetries", "3").toInt @@ -95,9 +105,7 @@ private[kafka010] class KafkaOffsetReader( * Closes the connection to Kafka, and cleans up state. */ def close(): Unit = { - runUninterruptibly { - consumer.close() - } + if (_consumer != null) runUninterruptibly { stopConsumer() } kafkaReaderThread.shutdown() } @@ -117,10 +125,14 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. + * + * @param partitionOffsets the specific offsets to resolve + * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = - runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long], + reportDataLoss: String => Unit): KafkaSourceOffset = { + val fetched = runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -145,6 +157,19 @@ private[kafka010] class KafkaOffsetReader( } } + partitionOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (fetched(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(fetched) + } + /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. @@ -287,19 +312,14 @@ private[kafka010] class KafkaOffsetReader( } } - /** - * Create a consumer using the new generated group id. 
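A worked example may help with the splitting branch of the KafkaOffsetRangeCalculator defined above. The numbers are illustrative, and the snippet assumes it runs inside the org.apache.spark.sql.kafka010 package since the classes are package-private:

import org.apache.kafka.common.TopicPartition

// Two partitions with 100 and 50 pending offsets and minPartitions = 3:
// total size = 150, ideal range size = 50, so t-0 is split into round(100/50) = 2 ranges
// while t-1 stays as round(50/50) = 1 range, giving roughly 3 Spark tasks in total.
val calc = new KafkaOffsetRangeCalculator(minPartitions = Some(3))
val tp0 = new TopicPartition("t", 0)
val tp1 = new TopicPartition("t", 1)
val ranges = calc.getRanges(
  fromOffsets = Map(tp0 -> 0L, tp1 -> 0L),
  untilOffsets = Map(tp0 -> 100L, tp1 -> 50L))
// => ranges covering (t-0, 0, 50), (t-0, 50, 100) and (t-1, 0, 50), in some order

Once a topic-partition is split this way, its ranges carry no preferred location, and the micro-batch reader above also disables consumer reuse because several ranges now share the same TopicPartition.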
We always use a new consumer to avoid - * just using a broken consumer to retry on Kafka errors, which likely will fail again. - */ - private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized { - val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) - newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) - consumerStrategy.createConsumer(newKafkaParams) + private def stopConsumer(): Unit = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer != null) _consumer.close() } private def resetConsumer(): Unit = synchronized { - consumer.close() - consumer = createConsumer() + stopConsumer() + _consumer = null // will automatically get reinitialized again } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala new file mode 100644 index 000000000000..f35a143e0037 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.consumer.ConsumerRecord + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.UTF8String + +/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ +private[kafka010] class KafkaRecordToUnsafeRowConverter { + private val rowWriter = new UnsafeRowWriter(7) + + def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { + rowWriter.reset() + + if (record.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, record.key) + } + rowWriter.write(1, record.value) + rowWriter.write(2, UTF8String.fromString(record.topic)) + rowWriter.write(3, record.partition) + rowWriter.write(4, record.offset) + rowWriter.write( + 5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) + rowWriter.write(6, record.timestampType.id) + rowWriter.getRow() + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index e9cff04ba5f2..1c7b3a29a861 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -130,7 +131,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) + case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,21 +139,6 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } - private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { - val result = kafkaReader.fetchSpecificOffsets(specificOffsets) - specificOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (result(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(result) - } - private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema @@ -223,6 +209,14 @@ private[kafka010] class KafkaSource( logInfo(s"GetBatch called with start = $start, end = $end") val untilPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(end) + // On recovery, getBatch will get called before getOffset + if (currentPartitionOffsets.isEmpty) { + currentPartitionOffsets = Some(untilPartitionOffsets) + } + if (start.isDefined && start.get == end) { + 
return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } val fromPartitionOffsets = start match { case Some(prevBatchEndOffset) => KafkaSourceOffset.getPartitionOffsets(prevBatchEndOffset) @@ -305,11 +299,6 @@ private[kafka010] class KafkaSource( logInfo("GetBatch generating RDD of offset range: " + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) - // On recovery, getBatch will get called before getOffset - if (currentPartitionOffsets.isEmpty) { - currentPartitionOffsets = Some(untilPartitionOffsets) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) } @@ -318,7 +307,7 @@ private[kafka010] class KafkaSource( kafkaReader.close() } - override def toString(): String = s"KafkaSource[$kafkaReader]" + override def toString(): String = s"KafkaSourceV1[$kafkaReader]" /** * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. @@ -335,22 +324,6 @@ private[kafka010] class KafkaSource( /** Companion object for the [[KafkaSource]]. */ private[kafka010] object KafkaSource { - val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = - """ - |Some data may have been lost because they are not available in Kafka any more; either the - | data was aged out by Kafka or the topic may have been deleted before all the data in the - | topic was processed. If you want your streaming query to fail on such cases, set the source - | option "failOnDataLoss" to "true". - """.stripMargin - - val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = - """ - |Some data may have been lost because they are not available in Kafka any more; either the - | data was aged out by Kafka or the topic may have been deleted before all the data in the - | topic was processed. If you don't want your streaming query to fail on such cases, set the - | source option "failOnDataLoss" to "false". - """.stripMargin - private[kafka010] val VERSION = 1 def getSortedExecutorList(sc: SparkContext): Array[String] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5da415b3097..8d41c0da2b13 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. 
*/ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } +private[kafka010] +case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) + extends PartitionOffset + /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3cb4d8cad12c..36b9f0466566 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, UUID} +import java.util.{Locale, Optional, UUID} import scala.collection.JavaConverters._ @@ -27,14 +27,16 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** - * The provider class for the [[KafkaSource]]. This provider is designed such that it throws + * The provider class for all Kafka readers and writers. It is designed such that it throws * IllegalArgumentException when the Kafka Dataset is created, so that it can catch * missing options even before the query is started. */ @@ -43,6 +45,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider + with StreamWriteSupport + with ContinuousReadSupport + with MicroBatchReadSupport with Logging { import KafkaSourceProvider._ @@ -101,6 +106,89 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader]] to read batches + * of Kafka data in a micro-batch streaming query. + */ + override def createMicroBatchReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceOptions): KafkaMicroBatchReader = { + + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. 
+ val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaMicroBatchReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + options, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader]] to read + * Kafka data in a continuous streaming query. + */ + override def createContinuousReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceOptions): KafkaContinuousReader = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaContinuousReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + parameters, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + /** * Returns a new base relation with the given parameters. 
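Both reader factories above forward only the options whose keys start with "kafka." to the underlying consumer, after dropping the six-character prefix. A small sketch of that filtering with made-up option values:

import java.util.Locale

// Only "kafka."-prefixed keys reach the Kafka client, with the prefix stripped;
// everything else (subscribe, startingOffsets, ...) is interpreted by the source itself.
val options = Map(
  "kafka.bootstrap.servers" -> "host1:9092,host2:9092",
  "subscribe"               -> "topic-a",
  "startingOffsets"         -> "earliest")

val specifiedKafkaParams = options.keySet
  .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
  .map(k => k.drop(6) -> options(k))
  .toMap
// => Map("bootstrap.servers" -> "host1:9092,host2:9092")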
* @@ -181,26 +269,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + import scala.collection.JavaConverters._ - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + val spark = SparkSession.getActiveSession.get + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. + val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + + KafkaWriter.validateQuery( + schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + + new KafkaStreamWriter(topic, producerParams, schema) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -264,12 +348,18 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister throw new IllegalArgumentException("Unknown option") } + // Validate minPartitions value if present + if (caseInsensitiveParams.contains(MIN_PARTITIONS_OPTION_KEY)) { + val p = caseInsensitiveParams(MIN_PARTITIONS_OPTION_KEY).toInt + if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive") + } + // Validate user-specified Kafka options if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { throw new IllegalArgumentException( s"Kafka option '${ConsumerConfig.GROUP_ID_CONFIG}' is not supported as " + - s"user-specified consumer groups is not used to track offsets.") + s"user-specified consumer groups are not used to track offsets.") } if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) { @@ -297,7 +387,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister { throw new IllegalArgumentException( s"Kafka option '${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame " + + "values are deserialized as byte arrays with ByteArrayDeserializer. 
Use DataFrame " + "operations to explicitly deserialize the values.") } @@ -371,8 +461,28 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + private val MIN_PARTITIONS_OPTION_KEY = "minpartitions" + val TOPIC_OPTION_KEY = "topic" + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you want your streaming query to fail on such cases, set the source + | option "failOnDataLoss" to "true". + """.stripMargin + + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you don't want your streaming query to fail on such cases, set the + | source option "failOnDataLoss" to "false". + """.stripMargin + + + private val deserClassName = classOf[ByteArrayDeserializer].getName def getKafkaOffsetRangeLimit( @@ -450,4 +560,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } + + private[kafka010] def kafkaParamsForProducer( + parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 66b3409c0cd0..498e344ea39f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -52,7 +52,7 @@ private[kafka010] case class KafkaSourceRDDPartition( * An RDD that reads data from Kafka based on offset ranges across multiple partitions. * Additionally, it allows preferred locations to be set for each topic + partition, so that * the [[KafkaSource]] can ensure the same executor always reads the same topic + partition - * and cached KafkaConsumers (see [[CachedKafkaConsumer]] can be used read data efficiently. + * and cached KafkaConsumers (see [[KafkaDataConsumer]] can be used read data efficiently. 
* * @param sc the [[SparkContext]] * @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors @@ -126,14 +126,9 @@ private[kafka010] class KafkaSourceRDD( val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] val topic = sourcePartition.offsetRange.topic val kafkaPartition = sourcePartition.offsetRange.partition - val consumer = - if (!reuseKafkaConsumer) { - // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we - // uses `assign`, we don't need to worry about the "group.id" conflicts. - CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams) - } else { - CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) - } + val consumer = KafkaDataConsumer.acquire( + sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) + val range = resolveRange(consumer, sourcePartition.offsetRange) assert( range.fromOffset <= range.untilOffset, @@ -167,13 +162,7 @@ private[kafka010] class KafkaSourceRDD( } override protected def close(): Unit = { - if (!reuseKafkaConsumer) { - // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - consumer.close() - } else { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams) - } + consumer.release() } } // Release consumer, either by removing it or indicating we're no longer using it @@ -184,7 +173,7 @@ private[kafka010] class KafkaSourceRDD( } } - private def resolveRange(consumer: CachedKafkaConsumer, range: KafkaSourceRDDOffsetRange) = { + private def resolveRange(consumer: KafkaDataConsumer, range: KafkaSourceRDDOffsetRange) = { if (range.fromOffset < 0 || range.untilOffset < 0) { // Late bind the offset range val availableOffsetRange = consumer.getAvailableOffsetRange() diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala new file mode 100644 index 000000000000..ae5b5c52d514 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.types.StructType + +/** + * Dummy commit message. 
The DataSourceV2 framework requires a commit message implementation but we + * don't need to really send one. + */ +case object KafkaWriterCommitMessage extends WriterCommitMessage + +/** + * A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory. + * + * @param topic The topic this writer is responsible for. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +class KafkaStreamWriter( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends StreamWriter with SupportsWriteInternalRow { + + validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) + + override def createInternalRowWriterFactory(): KafkaStreamWriterFactory = + KafkaStreamWriterFactory(topic, producerParams, schema) + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} +} + +/** + * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate + * the per-task data writers. + * @param topic The topic that should be written to. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +case class KafkaStreamWriterFactory( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends DataWriterFactory[InternalRow] { + + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { + new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) + } +} + +/** + * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to + * process incoming rows. + * + * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred + * from a `topic` field in the incoming data. + * @param producerParams Parameters to use for the Kafka producer. + * @param inputSchema The attributes in the input data. + */ +class KafkaStreamDataWriter( + targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) + extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { + import scala.collection.JavaConverters._ + + private lazy val producer = CachedKafkaProducer.getOrCreate( + new java.util.HashMap[String, Object](producerParams.asJava)) + + def write(row: InternalRow): Unit = { + checkForErrors() + sendRow(row, producer) + } + + def commit(): WriterCommitMessage = { + // Send is asynchronous, but we can't commit until all rows are actually in Kafka. + // This requires flushing and then checking that no callbacks produced errors. + // We also check for errors before to fail as soon as possible - the check is cheap. 
+ checkForErrors() + producer.flush() + checkForErrors() + KafkaWriterCommitMessage + } + + def abort(): Unit = {} + + def close(): Unit = { + checkForErrors() + if (producer != null) { + producer.flush() + checkForErrors() + CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 6fd333e2f43b..d90630a8adc9 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -26,17 +26,15 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Unsa import org.apache.spark.sql.types.{BinaryType, StringType} /** - * A simple trait for writing out data in a single Spark task, without any concerns about how + * Writes out data in a single Spark task, without any concerns about how * to commit or abort tasks. Exceptions thrown by the implementation of this class will * automatically trigger task aborts. */ private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) { + topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { // used to synchronize with Kafka callbacks - @volatile private var failedWrite: Exception = null - private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - val projectedRow = projection(currentRow) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - producer.send(record, callback) + sendRow(currentRow, producer) } } @@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask( producer = null } } +} + +private[kafka010] abstract class KafkaRowWriter( + inputSchema: Seq[Attribute], topic: Option[String]) { + + // used to synchronize with Kafka callbacks + @volatile protected var failedWrite: Exception = _ + protected val projection = createProjection + + private val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } - private def createProjection: UnsafeProjection = { + /** + * Send the specified row to the producer, with a callback that will save any exception + * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before + * assuming the row is in Kafka. 
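The commit path above relies on the usual KafkaProducer pattern: send asynchronously with a callback that records the first failure, then flush and re-check before declaring the batch committed. A minimal standalone sketch of that pattern, independent of the Spark classes in this diff:

import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}

// Minimal sketch of the async-send / flush / error-check pattern used by the writers above.
class SketchRowSender(producer: KafkaProducer[Array[Byte], Array[Byte]]) {
  @volatile private var failed: Exception = _   // written from the producer's I/O thread

  private val callback = new Callback {
    override def onCompletion(md: RecordMetadata, e: Exception): Unit =
      if (failed == null && e != null) failed = e
  }

  def send(topic: String, key: Array[Byte], value: Array[Byte]): Unit = {
    checkForErrors()                                        // fail fast if an earlier send failed
    producer.send(new ProducerRecord(topic, key, value), callback)
  }

  def commit(): Unit = {
    producer.flush()                                        // wait for all in-flight sends
    checkForErrors()                                        // surface any callback-recorded error
  }

  private def checkForErrors(): Unit = if (failed != null) throw failed
}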
+ */ + protected def sendRow( + row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { + val projectedRow = projection(row) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + producer.send(record, callback) + } + + protected def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } + + private def createProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } - - private def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 5e9ae35b3f00..15cd44812cb0 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - queryExecution: QueryExecution, + schema: Seq[Attribute], kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(queryExecution, kafkaParameters, topic) + validateQuery(schema, kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala similarity index 75% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala index 0372ad527095..43acd6a8d947 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala @@ -14,12 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer +package org.apache.spark.sql -/** - * Represents a step in configuring the driver init-container. 
- */ -private[spark] trait InitContainerConfigurationStep { +import org.apache.kafka.common.TopicPartition - def configureInitContainer(spec: InitContainerSpec): InitContainerSpec +package object kafka010 { // scalastyle:ignore + // ^^ scalastyle:ignore is for ignoring warnings about digits in package name + type PartitionOffsetMap = Map[TopicPartition, Long] } diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin new file mode 100644 index 000000000000..d530773f5732 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin @@ -0,0 +1,2 @@ +0v99999 +{"kafka-initial-offset-future-version":{"2":2,"1":1,"0":0}} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin index ae928e724967..8c78d9e390a0 100644 --- a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin +++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin @@ -1 +1 @@ -2{"kafka-initial-offset-2-1-0":{"2":0,"1":0,"0":0}} \ No newline at end of file +2{"kafka-initial-offset-2-1-0":{"2":2,"1":1,"0":0}} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala new file mode 100644 index 000000000000..ddfc0c1a4be2 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -0,0 +1,474 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Locale + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.types.{BinaryType, DataType} +import org.apache.spark.util.Utils + +/** + * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. + * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have + * to duplicate all the code. 
+ */ +class KafkaContinuousSinkSuite extends KafkaContinuousTest { + import testImplicits._ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("streaming - write to kafka with topic field") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Option[Int], Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("streaming - write w/o topic field, with topic option") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))() + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Option[Int], Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("streaming - topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. 
Then, when we read from the topic specified in the option, + * we should see that data, i.e. the rows were written to the topic given by the option, + * not to the topic ('foo') embedded in the data. + */ + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))( + withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Option[Int], Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("null topic attribute") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* Topic field is present but null, and no topic option is set */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "CAST(null as STRING) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getCause.getCause.getMessage + .toLowerCase(Locale.ROOT) + .contains("null topic present in the data.")) + } + + test("streaming - write data with bad schema") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase(Locale.ROOT) + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } +
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) + } + + test("streaming - write data with valid schema but wrong types") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value as STRING) value") + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) + + try { + /* value field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "value attribute type must be a string or binarytype")) + + try { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + } + throw writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getCause.getCause.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + testUtils.sendMessages(inputTopic, Array("0")) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .load() + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get 
+ } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) + } finally { + writer.stop() + } + + try { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) + } finally { + writer.stop() + } + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * it comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. + */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, String] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaStreamDataWriter(Some(topic), options.asScala.toMap, inputSchema) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + iter.foreach(writeTask.write(_)) + writeTask.commit() + } finally { + writeTask.close() + } + } + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + val checkpointDir = Utils.createTempDir() + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + // We need to reduce blocking time to efficiently test non-existent partition behavior.
+ .option("kafka.max.block.ms", "1000") + .trigger(Trigger.Continuous(1000)) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + stream.start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala new file mode 100644 index 000000000000..aab8ec42189f --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.streaming.Trigger + +// Run tests in KafkaSourceSuiteBase in continuous execution mode. +class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest + +class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { + import testImplicits._ + + override val brokerProps = Map("auto.create.topics.enable" -> "false") + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Execute { query => + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r + }.exists { r => + // Ensure the new topic is present and the old topic is gone. 
+ r.knownPartitions.exists(_.topic == topic2) + }, + s"query never reconfigured to new topic $topic2") + } + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } +} + +class KafkaContinuousSourceStressForDontFailOnDataLossSuite + extends KafkaSourceStressForDontFailOnDataLossSuite { + override protected def startStream(ds: Dataset[Int]) = { + ds.writeStream + .format("memory") + .queryName("memory") + .trigger(Trigger.Continuous("1 second")) + .start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala new file mode 100644 index 000000000000..fa1468a3943c --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.test.TestSparkSession + +// Trait to configure StreamTest for kafka continuous execution tests. +trait KafkaContinuousTest extends KafkaSourceTest { + override val defaultTrigger = Trigger.Continuous(1000) + override val defaultUseV2Sink = true + + // We need more than the default local[2] to be able to schedule all partitions simultaneously. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + // In addition to setting the partitions in Kafka, we have to wait until the query has + // reconfigured to the new count so the test framework can hook in properly. + override protected def setTopicPartitions( + topic: String, newCount: Int, query: StreamExecution) = { + testUtils.addPartitions(topic, newCount) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r + }.exists(_.knownPartitions.size == newCount), + s"query never reconfigured to $newCount partitions") + } + } + + // Continuous processing tasks end asynchronously, so test that they actually end. 
+ private val tasksEndedListener = new SparkListener() { + val activeTaskIdCount = new AtomicInteger(0) + + override def onTaskStart(start: SparkListenerTaskStart): Unit = { + activeTaskIdCount.incrementAndGet() + } + + override def onTaskEnd(end: SparkListenerTaskEnd): Unit = { + activeTaskIdCount.decrementAndGet() + } + } + + override def beforeEach(): Unit = { + super.beforeEach() + spark.sparkContext.addSparkListener(tasksEndedListener) + } + + override def afterEach(): Unit = { + eventually(timeout(streamingTimeout)) { + assert(tasksEndedListener.activeTaskIdCount.get() == 0) + } + spark.sparkContext.removeSparkListener(tasksEndedListener) + super.afterEach() + } + + + test("ensure continuous stream is being used") { + val query = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "1") + .load() + + testStream(query)( + Execute(q => assert(q.isInstanceOf[ContinuousExecution])) + ) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala new file mode 100644 index 000000000000..0d0fb9c3ab5a --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.{Executors, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.util.Random + +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.scalatest.PrivateMethodTester + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.ThreadUtils + +class KafkaDataConsumerSuite extends SharedSQLContext with PrivateMethodTester { + + protected var testUtils: KafkaTestUtils = _ + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils(Map[String, Object]()) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("SPARK-19886: Report error cause correctly in reportDataLoss") { + val cause = new Exception("D'oh!") + val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) + val e = intercept[IllegalStateException] { + InternalKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) + } + assert(e.getCause === cause) + } + + test("SPARK-23623: concurrent use of KafkaDataConsumer") { + val topic = "topic" + Random.nextInt() + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic, 1) + testUtils.sendMessages(topic, data.toArray) + val topicPartition = new TopicPartition(topic, 0) + + import ConsumerConfig._ + val kafkaParams = Map[String, Object]( + GROUP_ID_CONFIG -> "groupId", + BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress, + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ) + + val numThreads = 100 + val numConsumerUsages = 500 + + @volatile var error: Throwable = null + + def consume(i: Int): Unit = { + val useCache = Random.nextBoolean + val taskContext = if (Random.nextBoolean) { + new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) + } else { + null + } + TaskContext.setTaskContext(taskContext) + val consumer = KafkaDataConsumer.acquire( + topicPartition, kafkaParams.asJava, useCache) + try { + val range = consumer.getAvailableOffsetRange() + val rcvd = range.earliest until range.latest map { offset => + val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value() + new String(bytes) + } + assert(rcvd == data) + } catch { + case e: Throwable => + error = e + throw e + } finally { + consumer.release() + } + } + + val threadpool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numConsumerUsages).map { i => + threadpool.submit(new Runnable { + override def run(): Unit = { consume(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + assert(error == null) + } finally { + threadpool.shutdown() + } + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala similarity index 72% rename from external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala rename to 
external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 2034b9be07f2..d2d04b68de6a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -20,28 +20,34 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.{Locale, Properties} +import java.util.{Locale, Optional, Properties} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.io.Source import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata import org.apache.kafka.common.TopicPartition -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} -import org.apache.spark.util.Utils +import org.apache.spark.sql.types.StructType abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -49,9 +55,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds + protected val brokerProps = Map[String, Object]() + override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils + testUtils = new KafkaTestUtils(brokerProps) testUtils.setup() } @@ -59,18 +67,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null - super.afterAll() } + super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, + // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, // we don't know which data should be fetched when `startingOffsets` is latest. - q.processAllAvailable() + q match { + case c: ContinuousExecution => c.awaitEpoch(0) + case m: MicroBatchExecution => m.processAllAvailable() + } true } + protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { + testUtils.addPartitions(topic, newCount) + } + /** * Add data to Kafka. 
* @@ -82,10 +97,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { - if (query.get.isActive) { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + query match { // Make sure no Spark job is running when deleting a topic - query.get.processAllAvailable() + case Some(m: MicroBatchExecution) => m.processAllAvailable() + case _ => } val existingTopics = testUtils.getAllTopicsAndPartitionSize().toMap @@ -97,16 +113,22 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } - // Read all topics again in case some topics are delete. - val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") - val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => - source.asInstanceOf[KafkaSource] - } + val sources = { + query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: KafkaSource, _) => source + case StreamingExecutionRelation(source: KafkaMicroBatchReader, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case StreamingDataSourceV2Relation(_, _, _, reader: KafkaContinuousReader) => reader + } + }) + }.distinct + if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -137,187 +159,552 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } -} + private val topicId = new AtomicInteger(0) + protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" +} -class KafkaSourceSuite extends KafkaSourceTest { +abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { import testImplicits._ - private val topicId = new AtomicInteger(0) + test("(de)serialization of initial offsets") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) - testWithUninterruptibleThread( - "deserialization of initial offset with Spark 2.1.0") { - withTempDir { metadataPath => - val topic = newTopic - testUtils.createTopic(topic, partitions = 3) + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) - source.getOffset.get // Write initial offset - - // Make sure Spark 2.1.0 will throw an exception when reading the new log - intercept[java.lang.IllegalArgumentException] { - // Simulate how Spark 2.1.0 reads the log - Utils.tryWithResource(new FileInputStream(metadataPath.getAbsolutePath + "/0")) { in => - val length = in.read() - val bytes = new Array[Byte](length) - in.read(bytes) - KafkaSourceOffset(SerializedOffset(new String(bytes, UTF_8))) + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) + } + + 
test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) } } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() 
+ .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } + + test("ensure that initial offset are written with an extra byte in the beginning (SPARK-19517)") { + withTempDir { metadataPath => + val topic = "kafka-initial-offset-current" + testUtils.createTopic(topic, partitions = 1) + + val initialOffsetFile = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0").toFile + + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .load() + + // Test the written initial offset file has 0 byte in the beginning, so that + // Spark 2.1.0 can read the offsets (see SPARK-19517) + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + makeSureGetOffsetCalled) + + val binarySource = Source.fromFile(initialOffsetFile) + try { + assert(binarySource.next().toInt == 0) // first byte is binary 0 + } finally { + binarySource.close() + } } } - testWithUninterruptibleThread("deserialization of initial offset written by Spark 2.1.0") { + test("deserialization of initial offset written by Spark 2.1.0 (SPARK-19517)") { withTempDir { metadataPath => val topic = "kafka-initial-offset-2-1-0" testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, Array("0", "1", "2"), Some(0)) + testUtils.sendMessages(topic, Array("0", "10", "20"), Some(1)) + testUtils.sendMessages(topic, Array("0", "100", "200"), Some(2)) - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - + // Copy the initial offset file into the right location inside the checkpoint root directory + // such that the Kafka source can read it for initial offsets. val from = new File( getClass.getResource("/kafka-source-initial-offset-version-2.1.0.bin").toURI).toPath - val to = Paths.get(s"${metadataPath.getAbsolutePath}/0") + val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0") + Files.createDirectories(to.getParent) Files.copy(from, to) - val source = provider.createSource( - spark.sqlContext, metadataPath.toURI.toString, None, "", parameters) - val deserializedOffset = source.getOffset.get - val referenceOffset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) - assert(referenceOffset == deserializedOffset) + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + + // Test that the query starts from the expected initial offset (i.e. read older offsets, + // even though startingOffsets is latest). 
+ testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + AddKafkaData(Set(topic), 1000), + CheckAnswer(0, 1, 2, 10, 20, 200, 1000)) } } - testWithUninterruptibleThread("deserialization of initial offset written by future version") { + test("deserialization of initial offset written by future version") { withTempDir { metadataPath => - val futureMetadataLog = - new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, - metadataPath.getAbsolutePath) { - override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { - out.write(0) - val writer = new BufferedWriter(new OutputStreamWriter(out, UTF_8)) - writer.write(s"v99999\n${metadata.json}") - writer.flush - } - } - - val topic = newTopic + val topic = "kafka-initial-offset-future-version" testUtils.createTopic(topic, partitions = 3) - val offset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) - futureMetadataLog.add(0, offset) - - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) - val e = intercept[java.lang.IllegalStateException] { - source.getOffset.get // Read initial offset - } + // Copy the initial offset file into the right location inside the checkpoint root directory + // such that the Kafka source can read it for initial offsets. + val from = new File( + getClass.getResource("/kafka-source-initial-offset-future-version.bin").toURI).toPath + val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0") + Files.createDirectories(to.getParent) + Files.copy(from, to) - Seq( - s"maximum supported log version is v${KafkaSource.VERSION}, but encountered v99999", - "produced by a newer version of Spark and cannot be read by this version" - ).foreach { message => - assert(e.getMessage.contains(message)) - } + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + ExpectFailure[IllegalStateException](e => { + Seq( + s"maximum supported log version is v1, but encountered v99999", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.toString.contains(message)) + } + })) } } - test("(de)serialization of initial offsets") { + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() val topic = newTopic() - testUtils.createTopic(topic, partitions = 64) + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) - val reader = spark + val kafka = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") .option("subscribe", topic) + .load() - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") + + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + 
.queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() } - test("maxOffsetsPerTrigger") { + test("delete a topic when a Spark job is running") { + KafkaSourceSuite.collectedData.clear() + val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) val reader = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) .option("subscribe", topic) + // If a topic is deleted and we try to poll data starting from offset 0, + // the Kafka consumer will just block until timeout and return an empty result. + // So set the timeout to 1 second to make this test fast. + .option("kafkaConsumer.pollTimeoutMs", "1000") .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + KafkaSourceSuite.globalTestUtils = testUtils + // The following ForeachWriter will delete the topic before fetching data from Kafka + // in executors. 
+ val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + override def open(partitionId: Long, version: Long): Boolean = { + KafkaSourceSuite.globalTestUtils.deleteTopic(topic) + true + } + + override def process(value: Int): Unit = { + KafkaSourceSuite.collectedData.add(value) + } + + override def close(errorOrNull: Throwable): Unit = {} + }).start() + query.processAllAvailable() + query.stop() + // `failOnDataLoss` is `false`, we should not fail the query + assert(query.exception.isEmpty) + } + + test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { + def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, range.map(_.toString).toArray, Some(0)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 5) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + + reader.load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(k => k.toInt) + } + + val df1 = getSpecificDF(0 to 9) + val df2 = getSpecificDF(100 to 199) + + val kafka = df1.union(df2) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(kafka)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 5 from smaller topic, 5 from bigger one + CheckLastBatch((0 to 4) ++ (100 to 104): _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // 5 from smaller topic, 5 from bigger one + CheckLastBatch((5 to 9) ++ (105 to 109): _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smaller topic empty, 5 from bigger one + CheckLastBatch(110 to 114: _*), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 5 from bigger one + CheckLastBatch(115 to 119: _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 5 from bigger one + CheckLastBatch(120 to 124: _*) + ) + } + + test("ensure stream-stream self-join generates only one offset in log and correct metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + require(testUtils.getLatestOffsets(Set(topic)).size === 2) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .load() + + val values = kafka + .selectExpr("CAST(CAST(value AS STRING) AS INT) AS value", + "CAST(CAST(value AS STRING) AS INT) % 5 AS key") + + val join = values.join(values, "key") + + testStream(join)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + AddKafkaData(Set(topic), 6, 3), + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)), + AssertOnQuery { q => + assert(q.availableOffsets.iterator.size == 1) + assert(q.recentProgress.map(_.numInputRows).sum == 4) + true + } + ) + } +} + + +class KafkaMicroBatchV1SourceSuite extends KafkaMicroBatchSourceSuiteBase { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set( + 
"spark.sql.streaming.disabledV2MicroBatchReaders", + classOf[KafkaSourceProvider].getCanonicalName) + } + + test("V1 Source is used when disabled through SQLConf") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topic.*") + .load() + + testStream(kafka)( + makeSureGetOffsetCalled, + AssertOnQuery { query => + query.logicalPlan.collect { + case StreamingExecutionRelation(_: KafkaSource, _) => true + }.nonEmpty + } + ) + } +} + +class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { + + test("V2 Source is used by default") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topic.*") + .load() - val clock = new StreamManualClock + testStream(kafka)( + makeSureGetOffsetCalled, + AssertOnQuery { query => + query.logicalPlan.collect { + case StreamingExecutionRelation(_: KafkaMicroBatchReader, _) => true + }.nonEmpty + } + ) + } - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + testWithUninterruptibleThread("minPartitions is supported") { + import testImplicits._ + + val topic = newTopic() + val tp = new TopicPartition(topic, 0) + testUtils.createTopic(topic, partitions = 1) + + def test( + minPartitions: String, + numPartitionsGenerated: Int, + reusesConsumers: Boolean): Unit = { + + SparkSession.setActiveSession(spark) + withTempDir { dir => + val provider = new KafkaSourceProvider() + val options = Map( + "kafka.bootstrap.servers" -> testUtils.brokerAddress, + "subscribe" -> topic + ) ++ Option(minPartitions).map { p => "minPartitions" -> p} + val reader = provider.createMicroBatchReader( + Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava)) + reader.setOffsetRange( + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) + ) + val factories = reader.createUnsafeRowReaderFactories().asScala + .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) + withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { + assert(factories.size == numPartitionsGenerated) + factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } } } - if (q.exception.isDefined) { - throw q.exception.get - } - true } - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - 
waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) + // Test cases when minPartitions is used and not used + test(minPartitions = null, numPartitionsGenerated = 1, reusesConsumers = true) + test(minPartitions = "1", numPartitionsGenerated = 1, reusesConsumers = true) + test(minPartitions = "4", numPartitionsGenerated = 4, reusesConsumers = false) + + // Test illegal minPartitions values + intercept[IllegalArgumentException] { test(minPartitions = "a", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "1.0", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "0", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) } } +} + +abstract class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ + test("cannot stop Kafka stream") { val topic = newTopic() testUtils.createTopic(topic, partitions = 5) @@ -328,7 +715,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"topic-.*") + .option("subscribePattern", s"$topic.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -422,65 +809,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } - - test("starting offset is latest by default") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("0")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - val kafka = reader.load() - .selectExpr("CAST(value AS STRING)") - .as[String] - val mapped = kafka.map(_.toInt) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(1, 2, 3) // should not have 0 - ) - } - test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -540,77 +868,6 @@ class KafkaSourceSuite extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } - 
test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - - test("delete a topic when a Spark job is running") { - KafkaSourceSuite.collectedData.clear() - - val topic = newTopic() - testUtils.createTopic(topic, partitions = 1) - testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribe", topic) - // If a topic is deleted and we try to poll data starting from offset 0, - // the Kafka consumer will just block until timeout and return an empty result. - // So set the timeout to 1 second to make this test fast. - .option("kafkaConsumer.pollTimeoutMs", "1000") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - KafkaSourceSuite.globalTestUtils = testUtils - // The following ForeachWriter will delete the topic before fetching data from Kafka - // in executors. 
- val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { - override def open(partitionId: Long, version: Long): Boolean = { - KafkaSourceSuite.globalTestUtils.deleteTopic(topic) - true - } - - override def process(value: Int): Unit = { - KafkaSourceSuite.collectedData.add(value) - } - - override def close(errorOrNull: Throwable): Unit = {} - }).start() - query.processAllAvailable() - query.stop() - // `failOnDataLoss` is `false`, we should not fail the query - assert(query.exception.isEmpty) - } - test("get offsets from case insensitive parameters") { for ((optionKey, optionValue, answer) <- Seq( (STARTING_OFFSETS_OPTION_KEY, "earLiEst", EarliestOffsetRangeLimit), @@ -629,8 +886,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -676,6 +931,10 @@ class KafkaSourceSuite extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, + Execute { q => + // wait to reach the last offset in every partition + q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) + }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -704,13 +963,14 @@ class KafkaSourceSuite extends KafkaSourceTest { val query = kafka .writeStream .format("memory") - .outputMode("append") .queryName("kafkaColumnTypes") + .trigger(defaultTrigger) .start() - query.processAllAvailable() - val rows = spark.table("kafkaColumnTypes").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) + eventually(timeout(streamingTimeout)) { + assert(spark.table("kafkaColumnTypes").count == 1, + s"Unexpected results: ${spark.table("kafkaColumnTypes").collectAsList()}") + } + val row = spark.table("kafkaColumnTypes").head() assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row") assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row") assert(row.getAs[String]("topic") === topic, s"Unexpected results: $row") @@ -723,47 +983,6 @@ class KafkaSourceSuite extends KafkaSourceTest { query.stop() } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() - val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) - - val kafka = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") - .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. 
- assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() - } - private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -800,9 +1019,7 @@ class KafkaSourceSuite extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -843,9 +1060,7 @@ class KafkaSourceSuite extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -977,20 +1192,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { override def open(partitionId: Long, version: Long): Boolean = { true @@ -1004,6 +1207,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared override def close(errorOrNull: Throwable): Unit = { } }).start() + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = startStream(kafka.map(kv => kv._2.toInt)) val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala new file mode 100644 index 000000000000..2ccf3e291bea --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import scala.collection.JavaConverters._ + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.sources.v2.DataSourceOptions + +class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { + + def testWithMinPartitions(name: String, minPartition: Int) + (f: KafkaOffsetRangeCalculator => Unit): Unit = { + val options = new DataSourceOptions(Map("minPartitions" -> minPartition.toString).asJava) + test(s"with minPartition = $minPartition: $name") { + f(KafkaOffsetRangeCalculator(options)) + } + } + + + test("with no minPartition: N TopicPartitions to N offset ranges") { + val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 2)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 1), Seq.empty) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2), + executorLocations = Seq("location")) == + Seq(KafkaOffsetRange(tp1, 1, 2, Some("location")))) + } + + test("with no minPartition: empty ranges ignored") { + val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 1)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + } + + testWithMinPartitions("N TopicPartitions to N offset ranges", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 2, tp3 -> 2)) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp2, 1, 2, None), + KafkaOffsetRange(tp3, 1, 2, None))) + } + + testWithMinPartitions("1 TopicPartition to N offset ranges", 4) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 5)) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp1, 2, 3, None), + KafkaOffsetRange(tp1, 3, 4, None), + KafkaOffsetRange(tp1, 4, 5, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 5), + executorLocations = Seq("location")) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp1, 2, 3, None), + KafkaOffsetRange(tp1, 3, 4, None), + KafkaOffsetRange(tp1, 4, 5, None))) // location pref not set when minPartition is set + } + + testWithMinPartitions("N skewed TopicPartitions to M offset ranges", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 5, tp2 -> 21)) == + Seq( + KafkaOffsetRange(tp1, 1, 5, None), + KafkaOffsetRange(tp2, 1, 7, None), + KafkaOffsetRange(tp2, 7, 14, None), + KafkaOffsetRange(tp2, 14, 21, None))) + } + + testWithMinPartitions("range inexact multiple of minPartitions", 3) { 
calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 11)) == + Seq( + KafkaOffsetRange(tp1, 1, 4, None), + KafkaOffsetRange(tp1, 4, 7, None), + KafkaOffsetRange(tp1, 7, 11, None))) + } + + testWithMinPartitions("empty ranges ignored", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1), + untilOffsets = Map(tp1 -> 5, tp2 -> 21, tp3 -> 1)) == + Seq( + KafkaOffsetRange(tp1, 1, 5, None), + KafkaOffsetRange(tp2, 1, 7, None), + KafkaOffsetRange(tp2, 7, 14, None), + KafkaOffsetRange(tp2, 14, 21, None))) + } + + private val tp1 = new TopicPartition("t1", 1) + private val tp2 = new TopicPartition("t2", 1) + private val tp3 = new TopicPartition("t3", 1) +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 2ab336c7ac47..7079ac6453ff 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -138,7 +138,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { @@ -336,27 +336,31 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + assert(ex.getCause.getCause.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { val input = MemoryStream[String] var writer: StreamingQuery = null var ex: Exception = null - ex = intercept[IllegalArgumentException] { + ex = intercept[StreamingQueryException] { writer = createKafkaWriter( input.toDF(), withOptions = Map("kafka.key.serializer" -> "foo"))() + input.addData("1") + writer.processAllAvailable() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex.getCause.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'key.serializer' is not supported")) - ex = intercept[IllegalArgumentException] { + ex = intercept[StreamingQueryException] { writer = createKafkaWriter( input.toDF(), withOptions = Map("kafka.value.serializer" -> "foo"))() + input.addData("1") + writer.processAllAvailable() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex.getCause.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'value.serializer' is not supported")) } diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 6eb7ba5f0092..3b124b2a69d5 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala index fa3ea6131a50..aeb8c1dc342b 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala @@ -22,10 +22,8 @@ import java.{ util => ju } import 
org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } import org.apache.kafka.common.{ KafkaException, TopicPartition } -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging - /** * Consumer of single topicpartition, intended for cached reuse. * Underlying consumer is not threadsafe, so neither is this, @@ -38,7 +36,7 @@ class CachedKafkaConsumer[K, V] private( val partition: Int, val kafkaParams: ju.Map[String, Object]) extends Logging { - assert(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), + require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), "groupId used for cache key must match the groupId in kafkaParams") val topicPartition = new TopicPartition(topic, partition) @@ -53,7 +51,7 @@ class CachedKafkaConsumer[K, V] private( // TODO if the buffer was kept around as a random-access structure, // could possibly optimize re-calculating of an RDD in the same batch - protected var buffer = ju.Collections.emptyList[ConsumerRecord[K, V]]().iterator + protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() protected var nextOffset = -2L def close(): Unit = consumer.close() @@ -71,7 +69,7 @@ class CachedKafkaConsumer[K, V] private( } if (!buffer.hasNext()) { poll(timeout) } - assert(buffer.hasNext(), + require(buffer.hasNext(), s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") var record = buffer.next() @@ -79,17 +77,56 @@ class CachedKafkaConsumer[K, V] private( logInfo(s"Buffer miss for $groupId $topic $partition $offset") seek(offset) poll(timeout) - assert(buffer.hasNext(), + require(buffer.hasNext(), s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") record = buffer.next() - assert(record.offset == offset, - s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset") + require(record.offset == offset, + s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " + + s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + + "spark.streaming.kafka.allowNonConsecutiveOffsets" + ) } nextOffset = offset + 1 record } + /** + * Start a batch on a compacted topic + */ + def compactedStart(offset: Long, timeout: Long): Unit = { + logDebug(s"compacted start $groupId $topic $partition starting $offset") + // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics + if (offset != nextOffset) { + logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset") + seek(offset) + poll(timeout) + } + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + */ + def compactedNext(timeout: Long): ConsumerRecord[K, V] = { + if (!buffer.hasNext()) { + poll(timeout) + } + require(buffer.hasNext(), + s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout") + val record = buffer.next() + nextOffset = record.offset + 1 + record + } + + /** + * Rewind to previous record in the batch from a compacted topic. 
+ * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + buffer.previous() + } + private def seek(offset: Long): Unit = { logDebug(s"Seeking to $topicPartition $offset") consumer.seek(topicPartition, offset) @@ -99,7 +136,7 @@ class CachedKafkaConsumer[K, V] private( val p = consumer.poll(timeout) val r = p.records(topicPartition) logDebug(s"Polled ${p.partitions()} ${r.size}") - buffer = r.iterator + buffer = r.listIterator } } diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 0fa3287f36db..c3221481556f 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -56,6 +56,9 @@ private[spark] class DirectKafkaInputDStream[K, V]( ppc: PerPartitionConfig ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets { + private val initialRate = context.sparkContext.getConf.getLong( + "spark.streaming.backpressure.initialRate", 0) + val executorKafkaParams = { val ekp = new ju.HashMap[String, Object](consumerStrategy.executorKafkaParams) KafkaUtils.fixKafkaParams(ekp) @@ -126,7 +129,10 @@ private[spark] class DirectKafkaInputDStream[K, V]( protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate()) + val estimatedRateLimit = rateController.map { x => { + val lr = x.getLatestRate() + if (lr > 0) lr else initialRate + }} // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { @@ -138,17 +144,17 @@ private[spark] class DirectKafkaInputDStream[K, V]( lagPerPartition.map { case (tp, lag) => val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp) - val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) } + case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp).toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None @@ -184,8 +190,20 @@ private[spark] class DirectKafkaInputDStream[K, V]( // make sure new partitions are reflected in currentOffsets val newPartitions = parts.diff(currentOffsets.keySet) + + // Check whether any partitions have been revoked because of a consumer rebalance. + val revokedPartitions = currentOffsets.keySet.diff(parts) + if (revokedPartitions.nonEmpty) { + throw new IllegalStateException(s"Previously tracked partitions " + + s"${revokedPartitions.mkString("[", ",", "]")} have been revoked by Kafka because of a " + + s"consumer rebalance. 
This is most likely because another stream with the same group id has joined. " + + s"Please check whether multiple streaming applications are misconfigured to use the same " + + s"group id. Fundamentally different streams should use different group ids.") + } + + // position for new partitions determined by auto.offset.reset if no commit currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap + // don't want to consume messages, so pause c.pause(newPartitions.asJava) // find latest available offsets diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index d9fc9cc20664..07239eda64d2 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -55,12 +55,12 @@ private[spark] class KafkaRDD[K, V]( useConsumerCache: Boolean ) extends RDD[ConsumerRecord[K, V]](sc, Nil) with Logging with HasOffsetRanges { - assert("none" == + require("none" == kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).asInstanceOf[String], ConsumerConfig.AUTO_OFFSET_RESET_CONFIG + " must be set to none for executor kafka params, else messages may not match offsetRange") - assert(false == + require(false == kafkaParams.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG).asInstanceOf[Boolean], ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG + " must be set to false for executor kafka params, else offsets may commit before processing") @@ -74,6 +74,8 @@ private[spark] class KafkaRDD[K, V]( conf.getInt("spark.streaming.kafka.consumer.cache.maxCapacity", 64) private val cacheLoadFactor = conf.getDouble("spark.streaming.kafka.consumer.cache.loadFactor", 0.75).toFloat + private val compacted = + conf.getBoolean("spark.streaming.kafka.allowNonConsecutiveOffsets", false) override def persist(newLevel: StorageLevel): this.type = { logError("Kafka ConsumerRecord is not serializable. 
" + @@ -87,48 +89,63 @@ private[spark] class KafkaRDD[K, V]( }.toArray } - override def count(): Long = offsetRanges.map(_.count).sum + override def count(): Long = + if (compacted) { + super.count() + } else { + offsetRanges.map(_.count).sum + } override def countApprox( timeout: Long, confidence: Double = 0.95 - ): PartialResult[BoundedDouble] = { - val c = count - new PartialResult(new BoundedDouble(c, 1.0, c, c), true) - } - - override def isEmpty(): Boolean = count == 0L - - override def take(num: Int): Array[ConsumerRecord[K, V]] = { - val nonEmptyPartitions = this.partitions - .map(_.asInstanceOf[KafkaRDDPartition]) - .filter(_.count > 0) + ): PartialResult[BoundedDouble] = + if (compacted) { + super.countApprox(timeout, confidence) + } else { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } - if (num < 1 || nonEmptyPartitions.isEmpty) { - return new Array[ConsumerRecord[K, V]](0) + override def isEmpty(): Boolean = + if (compacted) { + super.isEmpty() + } else { + count == 0L } - // Determine in advance how many messages need to be taken from each partition - val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => - val remain = num - result.values.sum - if (remain > 0) { - val taken = Math.min(remain, part.count) - result + (part.index -> taken.toInt) + override def take(num: Int): Array[ConsumerRecord[K, V]] = + if (compacted) { + super.take(num) + } else if (num < 1) { + Array.empty[ConsumerRecord[K, V]] + } else { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (nonEmptyPartitions.isEmpty) { + Array.empty[ConsumerRecord[K, V]] } else { - result + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + context.runJob( + this, + (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => + it.take(parts(tc.partitionId)).toArray, parts.keys.toArray + ).flatten } } - val buf = new ArrayBuffer[ConsumerRecord[K, V]] - val res = context.runJob( - this, - (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => - it.take(parts(tc.partitionId)).toArray, parts.keys.toArray - ) - res.foreach(buf ++= _) - buf.toArray - } - private def executors(): Array[ExecutorCacheTaskLocation] = { val bm = sparkContext.env.blockManager bm.master.getPeers(bm.blockManagerId).toArray @@ -172,57 +189,138 @@ private[spark] class KafkaRDD[K, V]( override def compute(thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[K, V]] = { val part = thePart.asInstanceOf[KafkaRDDPartition] - assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) + require(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) if (part.fromOffset == part.untilOffset) { logInfo(s"Beginning offset ${part.fromOffset} is the same as ending offset " + s"skipping ${part.topic} ${part.partition}") Iterator.empty } else { - new KafkaRDDIterator(part, context) + logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " + + s"offsets ${part.fromOffset} -> ${part.untilOffset}") + if (compacted) { + new CompactedKafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) + } else { + new KafkaRDDIterator[K, V]( + part, + 
context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) + } } } +} - /** - * An iterator that fetches messages directly from Kafka for the offsets in partition. - * Uses a cached consumer where possible to take advantage of prefetching - */ - private class KafkaRDDIterator( - part: KafkaRDDPartition, - context: TaskContext) extends Iterator[ConsumerRecord[K, V]] { - - logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " + - s"offsets ${part.fromOffset} -> ${part.untilOffset}") - - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] +/** + * An iterator that fetches messages directly from Kafka for the offsets in partition. + * Uses a cached consumer where possible to take advantage of prefetching + */ +private class KafkaRDDIterator[K, V]( + part: KafkaRDDPartition, + context: TaskContext, + kafkaParams: ju.Map[String, Object], + useConsumerCache: Boolean, + pollTimeout: Long, + cacheInitialCapacity: Int, + cacheMaxCapacity: Int, + cacheLoadFactor: Float +) extends Iterator[ConsumerRecord[K, V]] { + + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + + context.addTaskCompletionListener(_ => closeIfNeeded()) + + val consumer = if (useConsumerCache) { + CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + if (context.attemptNumber >= 1) { + // just in case the prior attempt failures were cache related + CachedKafkaConsumer.remove(groupId, part.topic, part.partition) + } + CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) + } else { + CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + } - context.addTaskCompletionListener{ context => closeIfNeeded() } + var requestOffset = part.fromOffset - val consumer = if (useConsumerCache) { - CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber >= 1) { - // just in case the prior attempt failures were cache related - CachedKafkaConsumer.remove(groupId, part.topic, part.partition) - } - CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) - } else { - CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + def closeIfNeeded(): Unit = { + if (!useConsumerCache && consumer != null) { + consumer.close() } + } - var requestOffset = part.fromOffset + override def hasNext(): Boolean = requestOffset < part.untilOffset - def closeIfNeeded(): Unit = { - if (!useConsumerCache && consumer != null) { - consumer.close - } + override def next(): ConsumerRecord[K, V] = { + if (!hasNext) { + throw new ju.NoSuchElementException("Can't call getNext() once untilOffset has been reached") } + val r = consumer.get(requestOffset, pollTimeout) + requestOffset += 1 + r + } +} - override def hasNext(): Boolean = requestOffset < part.untilOffset - - override def next(): ConsumerRecord[K, V] = { - assert(hasNext(), "Can't call getNext() once untilOffset has been reached") - val r = consumer.get(requestOffset, pollTimeout) - requestOffset += 1 - r +/** + * An iterator that fetches messages directly from Kafka for the offsets in partition. + * Uses a cached consumer where possible to take advantage of prefetching. + * Intended for compacted topics, or other cases when non-consecutive offsets are ok. 
+ */ +private class CompactedKafkaRDDIterator[K, V]( + part: KafkaRDDPartition, + context: TaskContext, + kafkaParams: ju.Map[String, Object], + useConsumerCache: Boolean, + pollTimeout: Long, + cacheInitialCapacity: Int, + cacheMaxCapacity: Int, + cacheLoadFactor: Float + ) extends KafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) { + + consumer.compactedStart(part.fromOffset, pollTimeout) + + private var nextRecord = consumer.compactedNext(pollTimeout) + + private var okNext: Boolean = true + + override def hasNext(): Boolean = okNext + + override def next(): ConsumerRecord[K, V] = { + if (!hasNext) { + throw new ju.NoSuchElementException("Can't call getNext() once untilOffset has been reached") + } + val r = nextRecord + if (r.offset + 1 >= part.untilOffset) { + okNext = false + } else { + nextRecord = consumer.compactedNext(pollTimeout) + if (nextRecord.offset >= part.untilOffset) { + okNext = false + consumer.compactedPrevious() + } } + r } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 453b5e5ab20d..35e4678f2e3c 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka010 import java.io.File import java.lang.{ Long => JLong } -import java.util.{ Arrays, HashMap => JHashMap, Map => JMap } +import java.util.{ Arrays, HashMap => JHashMap, Map => JMap, UUID } import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -34,7 +34,7 @@ import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} @@ -617,6 +617,101 @@ class DirectKafkaStreamSuite ssc.stop() } + test("backpressure.initialRate should honor maxRatePerPartition") { + backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250) + } + + test("use backpressure.initialRate with backpressure") { + backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150) + } + + private def backpressureTest( + maxRatePerPartition: Int, + initialRate: Int, + maxMessagesPerPartition: Int) = { + + val topic = UUID.randomUUID().toString + val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. 
+ .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.backpressure.enabled", "true") + .set("spark.streaming.backpressure.initialRate", initialRate.toString) + .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString) + + val messages = Map("foo" -> 5000) + kafkaTestUtils.sendMessages(topic, messages) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) + } + kafkaStream.start() + + val input = Map(new TopicPartition(topic, 0) -> 1000L) + + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicPartition(topic, 0) -> maxMessagesPerPartition)) // we run for half a second + + kafkaStream.stop() + } + + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = getKafkaParams() + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimateRate = 1L + val fromOffsets = Map( + new TopicPartition(topic, 0) -> 0L, + new TopicPartition(topic, 1) -> 0L, + new TopicPartition(topic, 2) -> 0L, + new TopicPartition(topic, 3) -> 0L + ) + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) { + currentOffsets = fromOffsets + override val rateController = Some(new ConstantRateController(id, null, estimateRate)) + } + } + + val offsets = Map[TopicPartition, Long]( + new TopicPartition(topic, 0) -> 0, + new TopicPartition(topic, 1) -> 100L, + new TopicPartition(topic, 2) -> 200L, + new TopicPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + new TopicPartition(topic, 0) -> 1L, + new TopicPartition(topic, 1) -> 10L, + new TopicPartition(topic, 2) -> 20L, + new TopicPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[ConsumerRecord[K, V]]): Seq[(Time, Array[OffsetRange])] = { diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index be373af0599c..271adea1df73 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -18,16 +18,22 @@ package org.apache.spark.streaming.kafka010 import java.{ util => ju } +import java.io.File import scala.collection.JavaConverters._ import scala.util.Random +import kafka.common.TopicAndPartition +import kafka.log._ +import 
kafka.message._ +import kafka.utils.Pool import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.BeforeAndAfterAll import org.apache.spark._ import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.streaming.kafka010.mocks.MockTime class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -64,6 +70,41 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private val preferredHosts = LocationStrategies.PreferConsistent + private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]) { + val mockTime = new MockTime() + // LogCleaner in 0.10 version of Kafka is still expecting the old TopicAndPartition api + val logs = new Pool[TopicAndPartition, Log]() + val logDir = kafkaTestUtils.brokerLogDir + val dir = new File(logDir, topic + "-" + partition) + dir.mkdirs() + val logProps = new ju.Properties() + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + logProps.put(LogConfig.MinCleanableDirtyRatioProp, java.lang.Float.valueOf(0.1f)) + val log = new Log( + dir, + LogConfig(logProps), + 0L, + mockTime.scheduler, + mockTime + ) + messages.foreach { case (k, v) => + val msg = new ByteBufferMessageSet( + NoCompressionCodec, + new Message(v.getBytes, k.getBytes, Message.NoTimestamp, Message.CurrentMagicValue)) + log.append(msg) + } + log.roll() + logs.put(TopicAndPartition(topic, partition), log) + + val cleaner = new LogCleaner(CleanerConfig(), logDirs = Array(dir), logs = logs) + cleaner.startup() + cleaner.awaitCleaned(topic, partition, log.activeSegment.baseOffset, 1000) + + cleaner.shutdown() + mockTime.scheduler.shutdown() + } + + test("basic usage") { val topic = s"topicbasic-${Random.nextInt}-${System.currentTimeMillis}" kafkaTestUtils.createTopic(topic) @@ -102,6 +143,71 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("compacted topic") { + val compactConf = sparkConf.clone() + compactConf.set("spark.streaming.kafka.allowNonConsecutiveOffsets", "true") + sc.stop() + sc = new SparkContext(compactConf) + val topic = s"topiccompacted-${Random.nextInt}-${System.currentTimeMillis}" + + val messages = Array( + ("a", "1"), + ("a", "2"), + ("b", "1"), + ("c", "1"), + ("c", "2"), + ("b", "2"), + ("b", "3") + ) + val compactedMessages = Array( + ("a", "2"), + ("b", "3"), + ("c", "2") + ) + + compactLogs(topic, 0, messages) + + val props = new ju.Properties() + props.put("cleanup.policy", "compact") + props.put("flush.messages", "1") + props.put("segment.ms", "1") + props.put("segment.bytes", "256") + kafkaTestUtils.createTopic(topic, 1, props) + + + val kafkaParams = getKafkaParams() + + val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) + + val rdd = KafkaUtils.createRDD[String, String]( + sc, kafkaParams, offsetRanges, preferredHosts + ).map(m => m.key -> m.value) + + val received = rdd.collect.toSet + assert(received === compactedMessages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === compactedMessages.size) + assert(rdd.countApprox(0).getFinalValue.mean === compactedMessages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head === compactedMessages.head) + assert(rdd.take(messages.size + 10).size === compactedMessages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0)), preferredHosts) + + assert(emptyRdd.isEmpty) + + // invalid offset 
ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + val result = KafkaUtils.createRDD[String, String](sc, kafkaParams, badRanges, preferredHosts) + .map(_.value) + .collect() + } + } + test("iterator boundary conditions") { // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd val topic = s"topicboundary-${Random.nextInt}-${System.currentTimeMillis}" diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index 6c7024ea4b5a..70b579d96d69 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -162,17 +162,22 @@ private[kafka010] class KafkaTestUtils extends Logging { } /** Create a Kafka topic and wait until it is propagated to the whole cluster */ - def createTopic(topic: String, partitions: Int): Unit = { - AdminUtils.createTopic(zkUtils, topic, partitions, 1) + def createTopic(topic: String, partitions: Int, config: Properties): Unit = { + AdminUtils.createTopic(zkUtils, topic, partitions, 1, config) // wait until metadata is propagated (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) } } + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String, partitions: Int): Unit = { + createTopic(topic, partitions, new Properties()) + } + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ def createTopic(topic: String): Unit = { - createTopic(topic, 1) + createTopic(topic, 1, new Properties()) } /** Java-friendly function for sending messages to the Kafka broker */ @@ -196,12 +201,24 @@ private[kafka010] class KafkaTestUtils extends Logging { producer = null } + /** Send the array of (key, value) messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[(String, String)]): Unit = { + producer = new KafkaProducer[String, String](producerConfiguration) + messages.foreach { message => + producer.send(new ProducerRecord[String, String](topic, message._1, message._2)) + } + producer.close() + producer = null + } + + val brokerLogDir = Utils.createTempDir().getAbsolutePath + private def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") props.put("host.name", "localhost") props.put("port", brokerPort.toString) - props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("log.dir", brokerLogDir) props.put("zookeeper.connect", zkAddress) props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala new file mode 100644 index 000000000000..928e1a6ef54b --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010.mocks + +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.PriorityQueue + +import kafka.utils.{Scheduler, Time} + +/** + * A mock scheduler that executes tasks synchronously using a mock time instance. + * Tasks are executed synchronously when the time is advanced. + * This class is meant to be used in conjunction with MockTime. + * + * Example usage + * + * val time = new MockTime + * time.scheduler.schedule("a task", println("hello world: " + time.milliseconds), delay = 1000) + * time.sleep(1001) // this should cause our scheduled task to fire + * + * + * Incrementing the time to the exact next execution time of a task will result in that task + * executing (it as if execution itself takes no time). + */ +private[kafka010] class MockScheduler(val time: Time) extends Scheduler { + + /* a priority queue of tasks ordered by next execution time */ + var tasks = new PriorityQueue[MockTask]() + + def isStarted: Boolean = true + + def startup(): Unit = {} + + def shutdown(): Unit = synchronized { + tasks.foreach(_.fun()) + tasks.clear() + } + + /** + * Check for any tasks that need to execute. Since this is a mock scheduler this check only occurs + * when this method is called and the execution happens synchronously in the calling thread. + * If you are using the scheduler associated with a MockTime instance this call + * will be triggered automatically. + */ + def tick(): Unit = synchronized { + val now = time.milliseconds + while(!tasks.isEmpty && tasks.head.nextExecution <= now) { + /* pop and execute the task with the lowest next execution time */ + val curr = tasks.dequeue + curr.fun() + /* if the task is periodic, reschedule it and re-enqueue */ + if(curr.periodic) { + curr.nextExecution += curr.period + this.tasks += curr + } + } + } + + def schedule( + name: String, + fun: () => Unit, + delay: Long = 0, + period: Long = -1, + unit: TimeUnit = TimeUnit.MILLISECONDS): Unit = synchronized { + tasks += MockTask(name, fun, time.milliseconds + delay, period = period) + tick() + } + +} + +case class MockTask( + val name: String, + val fun: () => Unit, + var nextExecution: Long, + val period: Long) extends Ordered[MockTask] { + def periodic: Boolean = period >= 0 + def compare(t: MockTask): Int = { + java.lang.Long.compare(t.nextExecution, nextExecution) + } +} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala new file mode 100644 index 000000000000..a68f94db1f68 --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010.mocks + +import java.util.concurrent._ + +import kafka.utils.Time + +/** + * A class used for unit testing things which depend on the Time interface. + * + * This class never manually advances the clock, it only does so when you call + * sleep(ms) + * + * It also comes with an associated scheduler instance for managing background tasks in + * a deterministic way. + */ +private[kafka010] class MockTime(@volatile private var currentMs: Long) extends Time { + + val scheduler = new MockScheduler(this) + + def this() = this(System.currentTimeMillis) + + def milliseconds: Long = currentMs + + def nanoseconds: Long = + TimeUnit.NANOSECONDS.convert(currentMs, TimeUnit.MILLISECONDS) + + def sleep(ms: Long) { + this.currentMs += ms + scheduler.tick() + } + + override def toString(): String = s"MockTime($milliseconds)" + +} diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 786349474389..41bc8b3e3ee1 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 849c8b465f99..6d1c4789f382 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index d52c230eb784..9297c39d170c 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -91,9 +91,16 @@ class DirectKafkaInputDStream[ private val maxRateLimitPerPartition: Long = context.sparkContext.getConf.getLong( "spark.streaming.kafka.maxRatePerPartition", 0) + private val initialRate = context.sparkContext.getConf.getLong( + "spark.streaming.backpressure.initialRate", 0) + protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate()) + + val estimatedRateLimit = rateController.map { x => { + val lr = x.getLatestRate() + if (lr > 0) lr else initialRate + }} // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { @@ -104,17 +111,17 @@ class DirectKafkaInputDStream[ val totalLag = lagPerPartition.values.sum lagPerPartition.map { case (tp, lag) => - val backpressureRate = Math.round(lag / 
totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition.toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 5ea52b6ad36a..791cf0efaf88 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -191,6 +191,7 @@ class KafkaRDD[ private def fetchBatch: Iterator[MessageAndOffset] = { val req = new FetchRequestBuilder() + .clientId(consumer.clientId) .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes) .build() val resp = consumer.fetch(req) diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 06ef5bc3f8bd..ecca38784e77 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kafka import java.io.File -import java.util.Arrays +import java.util.{ Arrays, UUID } import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -32,12 +32,11 @@ import kafka.serializer.StringDecoder import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils @@ -456,6 +455,111 @@ class DirectKafkaStreamSuite ssc.stop() } + test("use backpressure.initialRate with backpressure") { + backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250) + } + + test("backpressure.initialRate should honor maxRatePerPartition") { + backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150) + } + + private def backpressureTest( + maxRatePerPartition: Int, + initialRate: Int, + maxMessagesPerPartition: Int) = { + + val topic = UUID.randomUUID().toString + val topicPartitions = Set(TopicAndPartition(topic, 0)) + kafkaTestUtils.createTopic(topic, 1) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val sparkConf = new 
SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.backpressure.enabled", "true") + .set("spark.streaming.backpressure.initialRate", initialRate.toString) + .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString) + + val messages = Map("foo" -> 5000) + kafkaTestUtils.sendMessages(topic, messages) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + + val kafkaStream = withClue("Error creating direct stream") { + val kc = new KafkaCluster(kafkaParams) + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val m = kc.getEarliestLeaderOffsets(topicPartitions) + .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset)) + + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, m, messageHandler) + } + kafkaStream.start() + + val input = Map(new TopicAndPartition(topic, 0) -> 1000L) + + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicAndPartition(topic, 0) -> maxMessagesPerPartition)) + + kafkaStream.stop() + } + + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimatedRate = 1L + val kafkaStream = withClue("Error creating direct stream") { + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val fromOffsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 0L, + TopicAndPartition(topic, 2) -> 0L, + TopicAndPartition(topic, 3) -> 0L + ) + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, fromOffsets, messageHandler) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, null) { + override def getLatestRate() = estimatedRate + }) + } + } + + val offsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 100L, + TopicAndPartition(topic, 2) -> 200L, + TopicAndPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + TopicAndPartition(topic, 0) -> 1L, + TopicAndPartition(topic, 1) -> 10L, + TopicAndPartition(topic, 2) -> 20L, + TopicAndPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 48783d65826a..37c7d1e604ec 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ 
org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 40a751a652fa..491589396559 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py index 4d7fc9a549bf..49794faab88c 100644 --- a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py +++ b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -34,7 +34,7 @@ $ export AWS_SECRET_KEY= # run the example - $ bin/spark-submit -jar external/kinesis-asl/target/scala-*/\ + $ bin/spark-submit -jars external/kinesis-asl/target/scala-*/\ spark-streaming-kinesis-asl-assembly_*.jar \ external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 36d555066b18..027157e53d51 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index cb30e4a4af4b..fbe77fcb958d 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index fd7b7f7c1c48..ebd65e8320e5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -303,7 +303,7 @@ object PageRank extends Logging { val src: VertexId = srcId.getOrElse(-1L) // Initialize the pagerankGraph with each edge attribute - // having weight 1/outDegree and each vertex with attribute 1.0. + // having weight 1/outDegree and each vertex with attribute 0. 
val pagerankGraph: Graph[(Double, Double), Double] = graph // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index aa36dd4774d8..2c39a7df0146 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml @@ -38,7 +38,32 @@ hadoop-cloud + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.hadoop + hadoop-client + ${hadoop.version} + provided + + + + hadoop-3.1 + + + + org.apache.hadoop + hadoop-cloud-storage + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + org.codehaus.jackson + jackson-mapper-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.google.guava + guava + + + + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + + + org.eclipse.jetty + jetty-util-ajax + ${jetty.version} + ${hadoop.deps.scope} + + + + diff --git a/launcher/pom.xml b/launcher/pom.xml index e9b46c4cf0ff..912eb6b6d2a0 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index df1e7316861d..9cbebdaeb33d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -18,39 +18,40 @@ package org.apache.spark.launcher; import java.io.IOException; -import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; abstract class AbstractAppHandle implements SparkAppHandle { - private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final Logger LOG = Logger.getLogger(AbstractAppHandle.class.getName()); private final LauncherServer server; - private LauncherConnection connection; + private LauncherServer.ServerConnection connection; private List listeners; - private State state; - private String appId; - private boolean disposed; + private AtomicReference state; + private volatile String appId; + private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; - this.state = State.UNKNOWN; + this.state = new AtomicReference<>(State.UNKNOWN); } @Override public synchronized void addListener(Listener l) { if (listeners == null) { - listeners = new ArrayList<>(); + listeners = new CopyOnWriteArrayList<>(); } listeners.add(l); } @Override public State getState() { - return state; + return state.get(); } @Override @@ -70,20 +71,17 @@ public void stop() { @Override public synchronized void disconnect() { - if (!disposed) { - disposed = true; - if (connection != null) { - try { - connection.close(); - } catch (IOException ioe) { - // no-op. - } + if (connection != null && connection.isOpen()) { + try { + connection.close(); + } catch (IOException ioe) { + // no-op. 
} - server.unregister(this); } + dispose(); } - void setConnection(LauncherConnection connection) { + void setConnection(LauncherServer.ServerConnection connection) { this.connection = connection; } @@ -95,21 +93,60 @@ boolean isDisposed() { return disposed; } + /** + * Mark the handle as disposed, and set it as LOST in case the current state is not final. + * + * This method should be called only when there's a reasonable expectation that the communication + * with the child application is not needed anymore, either because the code managing the handle + * has said so, or because the child application is finished. + */ + synchronized void dispose() { + if (!isDisposed()) { + // First wait for all data from the connection to be read. Then unregister the handle. + // Otherwise, unregistering might cause the server to be stopped and all child connections + // to be closed. + if (connection != null) { + try { + connection.waitForClose(); + } catch (IOException ioe) { + // no-op. + } + } + server.unregister(this); + + // Set state to LOST if not yet final. + setState(State.LOST, false); + this.disposed = true; + } + } + void setState(State s) { setState(s, false); } - synchronized void setState(State s, boolean force) { - if (force || !state.isFinal()) { - state = s; + void setState(State s, boolean force) { + if (force) { + state.set(s); fireEvent(false); - } else { + return; + } + + State current = state.get(); + while (!current.isFinal()) { + if (state.compareAndSet(current, s)) { + fireEvent(false); + return; + } + current = state.get(); + } + + if (s != State.LOST) { LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", - new Object[] { state, s }); + new Object[] { current, s }); } } - synchronized void setAppId(String appId) { + void setAppId(String appId) { this.appId = appId; fireEvent(true); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java index 44e69fc45dff..4e02843480e8 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java @@ -139,7 +139,7 @@ public T setMainClass(String mainClass) { public T addSparkArg(String arg) { SparkSubmitOptionParser validator = new ArgumentValidator(false); validator.parse(Arrays.asList(arg)); - builder.sparkArgs.add(arg); + builder.userArgs.add(arg); return self(); } @@ -187,8 +187,8 @@ public T addSparkArg(String name, String value) { } } else { validator.parse(Arrays.asList(name, value)); - builder.sparkArgs.add(name); - builder.sparkArgs.add(value); + builder.userArgs.add(name); + builder.userArgs.add(value); } return self(); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 8b3f427b7750..5609f8492f4f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,14 +48,16 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); + if (!isDisposed()) { + setState(State.KILLED); + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); + } + childProc = null; } - childProc = null; } - setState(State.KILLED); } 
void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -94,8 +96,6 @@ void monitorChild() { return; } - disconnect(); - int ec; try { ec = proc.exitValue(); @@ -104,20 +104,15 @@ void monitorChild() { ec = 1; } - State currState = getState(); - State newState = null; if (ec != 0) { + State currState = getState(); // Override state with failure if the current state is not final, or is success. if (!currState.isFinal() || currState == State.FINISHED) { - newState = State.FAILED; + setState(State.FAILED, true); } - } else if (!currState.isFinal()) { - newState = State.LOST; } - if (newState != null) { - setState(newState, true); - } + dispose(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index 0d6a73a3da3e..15fbca0facef 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.io.IOException; import java.lang.reflect.Method; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; @@ -26,7 +25,7 @@ class InProcessAppHandle extends AbstractAppHandle { private static final String THREAD_NAME_FMT = "spark-app-%d: '%s'"; - private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final Logger LOG = Logger.getLogger(InProcessAppHandle.class.getName()); private static final AtomicLong THREAD_IDS = new AtomicLong(); // Avoid really long thread names. @@ -40,15 +39,16 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. - if (app != null) { - app.interrupt(); + if (!isDisposed()) { + LOG.warning("kill() may leave the underlying app running in in-process mode."); + setState(State.KILLED); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); + } } - - setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { @@ -66,14 +66,7 @@ synchronized void start(String appName, Method main, String[] args) { setState(State.FAILED); } - synchronized (InProcessAppHandle.this) { - if (!isDisposed()) { - disconnect(); - if (!getState().isFinal()) { - setState(State.LOST, true); - } - } - } + dispose(); }); app.setName(String.format(THREAD_NAME_FMT, THREAD_IDS.incrementAndGet(), appName)); diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java index 6d726b4a69a8..688e1f763c20 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java @@ -89,10 +89,18 @@ Method findSparkSubmit() throws IOException { } Class sparkSubmit; + // SPARK-22941: first try the new SparkSubmit interface that has better error handling, + // but fall back to the old interface in case someone is mixing & matching launcher and + // Spark versions. 
try { - sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit"); - } catch (Exception e) { - throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.", e); + sparkSubmit = cl.loadClass("org.apache.spark.deploy.InProcessSparkSubmit"); + } catch (Exception e1) { + try { + sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit"); + } catch (Exception e2) { + throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.", + e2); + } } Method main; diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index b4a8719e2605..e8ab3f5e369a 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -53,7 +53,7 @@ abstract class LauncherConnection implements Closeable, Runnable { public void run() { try { FilteredObjectInputStream in = new FilteredObjectInputStream(socket.getInputStream()); - while (!closed) { + while (isOpen()) { Message msg = (Message) in.readObject(); handle(msg); } @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public void close() throws IOException { - if (!closed) { - synchronized (this) { - if (!closed) { - closed = true; - socket.close(); - } - } + public synchronized void close() throws IOException { + if (isOpen()) { + closed = true; + socket.close(); } } + boolean isOpen() { + return !closed; + } + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index b8999a1d7a4f..607879fd02ea 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,6 +217,7 @@ void unregister(AbstractAppHandle handle) { break; } } + unref(); } @@ -237,6 +238,7 @@ public void run() { }; ServerConnection clientConnection = new ServerConnection(client, timeout); Thread clientThread = factory.newThread(clientConnection); + clientConnection.setConnectionThread(clientThread); synchronized (clients) { clients.add(clientConnection); } @@ -285,16 +287,21 @@ private String createSecret() { } } - private class ServerConnection extends LauncherConnection { + class ServerConnection extends LauncherConnection { private TimerTask timeout; - private AbstractAppHandle handle; + private volatile Thread connectionThread; + private volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); this.timeout = timeout; } + void setConnectionThread(Thread t) { + this.connectionThread = t; + } + @Override protected void handle(Message msg) throws IOException { try { @@ -313,7 +320,7 @@ protected void handle(Message msg) throws IOException { } else { if (handle == null) { throw new IllegalArgumentException("Expected hello, got: " + - msg != null ? msg.getClass().getName() : null); + msg != null ? 
msg.getClass().getName() : null); } if (msg instanceof SetAppId) { SetAppId set = (SetAppId) msg; @@ -331,6 +338,9 @@ protected void handle(Message msg) throws IOException { timeout.cancel(); } close(); + if (handle != null) { + handle.dispose(); + } } finally { timeoutTimer.purge(); } @@ -338,16 +348,42 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { + if (!isOpen()) { + return; + } + synchronized (clients) { clients.remove(this); } + super.close(); - if (handle != null) { - if (!handle.getState().isFinal()) { - LOG.log(Level.WARNING, "Lost connection to spark application."); - handle.setState(SparkAppHandle.State.LOST); + } + + /** + * Wait for the remote side to close the connection so that any pending data is processed. + * This ensures any changes reported by the child application take effect. + * + * This method allows a short period for the above to happen (same amount of time as the + * connection timeout, which is configurable). This should be fine for well-behaved + * applications, where they close the connection around the same time the app handle detects the + * app has finished. + * + * In case the connection is not closed within the grace period, this method forcefully closes + * it and any subsequent data that may arrive will be ignored. + */ + public void waitForClose() throws IOException { + Thread connThread = this.connectionThread; + if (Thread.currentThread() != connThread) { + try { + connThread.join(getConnectionTimeout()); + } catch (InterruptedException ie) { + // Ignore. + } + + if (connThread.isAlive()) { + LOG.log(Level.WARNING, "Timed out waiting for child connection to close."); + close(); } - handle.disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index e0ef22d7d505..5cb6457bf5c2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -88,8 +88,9 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkLauncher.NO_RESOURCE); } - final List<String> sparkArgs; - private final boolean isAppResourceReq; + final List<String> userArgs; + private final List<String> parsedArgs; + private final boolean requiresAppResource; private final boolean isExample; /** @@ -99,17 +100,27 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { */ private boolean allowsMixedArguments; + /** + * This constructor is used when creating a user-configurable launcher. It allows the + * spark-submit argument list to be modified after creation. + */ SparkSubmitCommandBuilder() { - this.sparkArgs = new ArrayList<>(); - this.isAppResourceReq = true; + this.requiresAppResource = true; this.isExample = false; + this.parsedArgs = new ArrayList<>(); + this.userArgs = new ArrayList<>(); } + /** + * This constructor is used when invoking spark-submit; it parses and validates arguments + * provided by the user on the command line. 
+ */ SparkSubmitCommandBuilder(List<String> args) { this.allowsMixedArguments = false; - this.sparkArgs = new ArrayList<>(); + this.parsedArgs = new ArrayList<>(); boolean isExample = false; List<String> submitArgs = args; + this.userArgs = Collections.emptyList(); if (args.size() > 0) { switch (args.get(0)) { @@ -131,21 +142,21 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } this.isExample = isExample; - OptionParser parser = new OptionParser(); + OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.isAppResourceReq = parser.isAppResourceReq; - } else { + this.requiresAppResource = parser.requiresAppResource; + } else { this.isExample = isExample; - this.isAppResourceReq = false; + this.requiresAppResource = false; } } @Override public List<String> buildCommand(Map<String, String> env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && isAppResourceReq) { + if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && isAppResourceReq) { + } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -154,9 +165,19 @@ public List<String> buildCommand(Map<String, String> env) List<String> buildSparkSubmitArgs() { List<String> args = new ArrayList<>(); - SparkSubmitOptionParser parser = new SparkSubmitOptionParser(); + OptionParser parser = new OptionParser(false); + final boolean requiresAppResource; + + // If the user args array is not empty, we need to parse it to detect exactly what + // the user is trying to run, so that checks below are correct. + if (!userArgs.isEmpty()) { + parser.parse(userArgs); + requiresAppResource = parser.requiresAppResource; + } else { + requiresAppResource = this.requiresAppResource; + } - if (!allowsMixedArguments && isAppResourceReq) { + if (!allowsMixedArguments && requiresAppResource) { checkArgument(appResource != null, "Missing application resource."); } @@ -208,15 +229,16 @@ List<String> buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isAppResourceReq) { - checkArgument(!isExample || mainClass != null, "Missing example class name."); + if (isExample) { + checkArgument(mainClass != null, "Missing example class name."); } + if (mainClass != null) { args.add(parser.CLASS); args.add(mainClass); } - args.addAll(sparkArgs); + args.addAll(parsedArgs); if (appResource != null) { args.add(appResource); } @@ -399,7 +421,12 @@ private List<String> findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean isAppResourceReq = true; + boolean requiresAppResource = true; + private final boolean errorOnUnknownArgs; + + OptionParser(boolean errorOnUnknownArgs) { + this.errorOnUnknownArgs = errorOnUnknownArgs; + } @Override protected boolean handle(String opt, String value) { @@ -443,23 +470,23 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - isAppResourceReq = false; - sparkArgs.add(opt); - sparkArgs.add(value); + requiresAppResource = false; + parsedArgs.add(opt); + parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - isAppResourceReq = false; - sparkArgs.add(opt); + requiresAppResource = false; + parsedArgs.add(opt); break; case VERSION: - isAppResourceReq = false; - sparkArgs.add(opt); + requiresAppResource = false; + parsedArgs.add(opt); break; default: - sparkArgs.add(opt); + parsedArgs.add(opt); if (value != null) { - sparkArgs.add(value); + parsedArgs.add(value); } break; } @@ -483,12 
+510,13 @@ protected boolean handleUnknown(String opt) { mainClass = className; appResource = SparkLauncher.NO_RESOURCE; return false; - } else { + } else if (errorOnUnknownArgs) { checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); checkState(appResource == null, "Found unrecognized argument but resource is already set."); appResource = opt; return false; } + return true; } @Override diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3e1a90eae98d..438349e027a2 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,7 +17,7 @@ package org.apache.spark.launcher; -import java.util.concurrent.TimeUnit; +import java.time.Duration; import org.junit.After; import org.slf4j.bridge.SLF4JBridgeHandler; @@ -47,19 +47,46 @@ public void postChecks() { assertNull(server); } - protected void waitFor(SparkAppHandle handle) throws Exception { - long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + protected void waitFor(final SparkAppHandle handle) throws Exception { try { - while (!handle.getState().isFinal()) { - assertTrue("Timed out waiting for handle to transition to final state.", - System.nanoTime() < deadline); - TimeUnit.MILLISECONDS.sleep(10); - } + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is not in final state.", handle.getState().isFinal()); + }); } finally { if (!handle.getState().isFinal()) { handle.kill(); } } + + // Wait until the handle has been marked as disposed, to make sure all cleanup tasks + // have been performed. + AbstractAppHandle ahandle = (AbstractAppHandle) handle; + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); + }); + } + + /** + * Call a closure that performs a check every "period" until it succeeds, or the timeout + * elapses. 
+ */ + protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { + assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); + long deadline = System.nanoTime() + timeout.toNanos(); + int count = 0; + while (true) { + try { + count++; + check.run(); + return; + } catch (Throwable t) { + if (System.nanoTime() >= deadline) { + String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); + throw new IllegalStateException(msg, t); + } + Thread.sleep(period.toMillis()); + } + } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 7e2b09ce25c9..f8dc0ec7a0bf 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,12 +23,14 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; +import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -92,8 +94,8 @@ public void infoChanged(SparkAppHandle handle) { Message stopMsg = client.inbound.poll(30, TimeUnit.SECONDS); assertTrue(stopMsg instanceof Stop); } finally { - handle.kill(); close(client); + handle.kill(); client.clientThread.join(); } } @@ -143,7 +145,8 @@ public void infoChanged(SparkAppHandle handle) { assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); // Make sure the server matched the client to the handle. assertNotNull(handle.getConnection()); - close(client); + client.close(); + handle.dispose(); assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); assertEquals(SparkAppHandle.State.LOST, handle.getState()); } finally { @@ -182,6 +185,34 @@ public void testStreamFiltering() throws Exception { } } + @Test + public void testAppHandleDisconnect() throws Exception { + LauncherServer server = LauncherServer.getOrCreateServer(); + ChildProcAppHandle handle = new ChildProcAppHandle(server); + String secret = server.registerHandle(handle); + + TestClient client = null; + try { + Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort()); + client = new TestClient(s); + client.send(new Hello(secret, "1.4.0")); + client.send(new SetAppId("someId")); + + // Wait until we know the server has received the messages and matched the handle to the + // connection before disconnecting. + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { + assertEquals("someId", handle.getAppId()); + }); + + handle.disconnect(); + waitForError(client, secret); + } finally { + handle.kill(); + close(client); + client.clientThread.join(); + } + } + private void close(Closeable c) { if (c != null) { try { @@ -197,28 +228,20 @@ private void close(Closeable c) { * server-side close immediately. 
*/ private void waitForError(TestClient client, String secret) throws Exception { - boolean helloSent = false; - int maxTries = 10; - for (int i = 0; i < maxTries; i++) { + final AtomicBoolean helloSent = new AtomicBoolean(); + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { try { - if (!helloSent) { + if (!helloSent.get()) { client.send(new Hello(secret, "1.4.0")); - helloSent = true; + helloSent.set(true); } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. - break; - } catch (AssertionError e) { - if (i < maxTries - 1) { - Thread.sleep(100); - } else { - throw new AssertionError("Test failed after " + maxTries + " attempts.", e); - } } - } + }); } private static class TestClient extends LauncherConnection { diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 043d13609fd2..53286fe93478 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 941b6eca568d..5824e463ca1a 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -565,7 +565,7 @@ class SparseVector @Since("2.0.0") ( // validate the data { - require(size >= 0, "The size of the requested sparse vector must be greater than 0.") + require(size >= 0, "The size of the requested sparse vector must be no less than 0.") require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. 
You provided ${indices.length} indices and " + s" ${values.length} values.") diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index 79acef8214d8..0a316f57f811 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -366,4 +366,18 @@ class VectorsSuite extends SparkMLFunSuite { assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) } + + test("sparse vector only support non-negative length") { + val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) + val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + assert(v1.size === 0) + assert(v2.size === 0) + + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array(1), Array(2.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array((1, 2.0))) + } + } } diff --git a/mllib/pom.xml b/mllib/pom.xml index a906c9e02cd4..f07d7f24fd31 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 000000000000..f14431d50fee --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1,4 @@ +org.apache.spark.ml.regression.InternalLinearRegressionModelWriter +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter +org.apache.spark.ml.clustering.InternalKMeansModelWriter +org.apache.spark.ml.clustering.PMMLKMeansModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 08b0cb9b8f6a..d8f3dfa87443 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -219,7 +219,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, /** * Predict label for the given features. - * This internal method is used to implement `transform()` and output [[predictionCol]]. + * This method is used to implement `transform()` and output [[predictionCol]]. */ - protected def predict(features: FeaturesType): Double + @Since("2.4.0") + def predict(features: FeaturesType): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index bc0b49d48d32..7e5790ab70ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, VectorUDT} @@ -109,7 +109,7 @@ abstract class Classifier[ case None => // Get number of classes from dataset itself. 
val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1) - if (maxLabelRow.isEmpty) { + if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) { throw new SparkException("ML algorithm was given empty dataset.") } val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0) @@ -192,12 +192,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur /** * Predict label for the given features. - * This internal method is used to implement `transform()` and output [[predictionCol]]. + * This method is used to implement `transform()` and output [[predictionCol]]. * * This default implementation for classification predicts the index of the maximum value * from `predictRaw()`. */ - override protected def predict(features: FeaturesType): Double = { + override def predict(features: FeaturesType): Double = { raw2prediction(predictRaw(features)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 9f60f0896ec5..57797d1cc497 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -165,7 +165,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi @Since("1.4.0") class DecisionTreeClassificationModel private[ml] ( @Since("1.4.0")override val uid: String, - @Since("1.4.0")override val rootNode: Node, + @Since("1.4.0")override val rootNode: ClassificationNode, @Since("1.6.0")override val numFeatures: Int, @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] @@ -178,10 +178,10 @@ class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = + private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } @@ -276,9 +276,10 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession) - val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) + val model = new DecisionTreeClassificationModel(metadata.uid, + root.asInstanceOf[ClassificationNode], numFeatures, numClasses) + metadata.getAndSetParams(model) model } } @@ -292,9 +293,10 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica require(oldModel.algo == OldAlgo.Classification, s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). 
Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") // Can't infer number of features from old model, so default to -1 - new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1) + new DecisionTreeClassificationModel(uid, + rootNode.asInstanceOf[ClassificationNode], numFeatures, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f11bc1d8fe41..0aa24f0a3cfc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -267,7 +267,7 @@ class GBTClassificationModel private[ml]( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization if (isDefined(thresholds)) { super.predict(features) @@ -371,22 +371,22 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { override def load(path: String): GBTClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int] val numTrees = (metadata.metadata \ numTreesKey).extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index ce400f4f1faf..80c537e1e0eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -316,7 +316,7 @@ class LinearSVCModel private[classification] ( BLAS.dot(features, coefficients) + intercept } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { if (margin(features) > $(threshold)) 1.0 else 0.0 } @@ -377,7 +377,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { val Row(coefficients: Vector, intercept: Double) = data.select("coefficients", "intercept").head() val model = new LinearSVCModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + 
metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index fa191604218d..06ca37bc7514 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -500,7 +500,7 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept) @@ -517,6 +517,9 @@ class LogisticRegression @Since("1.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } + instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) + instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -560,15 +563,15 @@ class LogisticRegression @Since("1.2.0") ( if (numInvalid != 0) { val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " + s"Found $numInvalid invalid labels." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } val isConstantLabel = histogram.count(_ != 0.0) == 1 if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) { - logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + - s"will be zeros. Training is not needed.") + instr.logWarning(s"All labels are the same value and fitIntercept=true, so the " + + s"coefficients will be zeros. Training is not needed.") val constantLabelIndex = Vectors.dense(histogram).argmax val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures, new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double], @@ -581,7 +584,7 @@ class LogisticRegression @Since("1.2.0") ( (coefMatrix, interceptVec, Array.empty[Double]) } else { if (!$(fitIntercept) && isConstantLabel) { - logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " + + instr.logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " + s"dangerous ground, so the algorithm may not converge.") } @@ -590,7 +593,7 @@ class LogisticRegression @Since("1.2.0") ( if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LogisticRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LogisticRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant " + "nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -708,7 +711,7 @@ class LogisticRegression @Since("1.2.0") ( (_initialModel.interceptVector.size == numCoefficientSets) && (_initialModel.getFitIntercept == $(fitIntercept)) if (!modelIsValid) { - logWarning(s"Initial coefficients will be ignored! Its dimensions " + + instr.logWarning(s"Initial coefficients will be ignored! 
Its dimensions " + s"(${providedCoefs.numRows}, ${providedCoefs.numCols}) did not match the " + s"expected size ($numCoefficientSets, $numFeatures)") } @@ -813,7 +816,7 @@ class LogisticRegression @Since("1.2.0") ( if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -1090,7 +1093,7 @@ class LogisticRegressionModel private[spark] ( * Predict label for the given feature vector. * The behavior of this can be adjusted using `thresholds`. */ - override protected def predict(features: Vector): Double = if (isMultinomial) { + override def predict(features: Vector): Double = if (isMultinomial) { super.predict(features) } else { // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. @@ -1267,7 +1270,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { numClasses, isMultinomial) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index fd4c98f22132..57ba47e596a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -322,7 +322,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( * Predict label for the given features. * This internal method is used to implement `transform()` and output [[predictionCol]]. */ - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { LabelConverter.decodeLabel(mlpModel.predict(features)) } @@ -388,7 +388,7 @@ object MultilayerPerceptronClassificationModel val weights = data.getAs[Vector](1) val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 0293e03d4743..45fb585ed226 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -407,7 +407,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index f04fde2cbbca..7df53a6b8ad1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -32,7 +32,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol} import org.apache.spark.ml.util._ @@ 
-55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait { /** * Params for [[OneVsRest]]. */ -private[ml] trait OneVsRestParams extends PredictorParams +private[ml] trait OneVsRestParams extends ClassifierParams with ClassifierTypeTrait with HasWeightCol { /** @@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] ( @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { + require(models.nonEmpty, "OneVsRestModel requires at least one model for one class") + + @Since("2.4.0") + val numClasses: Int = models.length + + @Since("2.4.0") + val numFeatures: Int = models.head.numFeatures + /** @group setParam */ @Since("2.1.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) @@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] ( val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => predictions + ((index, prediction(1))) } + model.setFeaturesCol($(featuresCol)) val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset @@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] ( newDataset.unpersist() } - // output the index of the classifier with highest confidence as prediction - val labelUDF = udf { (predictions: Map[Int, Double]) => - predictions.maxBy(_._2)._1.toDouble - } + if (getRawPredictionCol != "") { + val numClass = models.length - // output label and label metadata as prediction - aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) - .drop(accColName) + // output the RawPrediction as vector + val rawPredictionUDF = udf { (predictions: Map[Int, Double]) => + val predArray = Array.fill[Double](numClass)(0.0) + predictions.foreach { case (idx, value) => predArray(idx) = value } + Vectors.dense(predArray) + } + + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble } + + // output confidence as raw prediction, label and label metadata as prediction + aggregatedDataset + .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName))) + .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata) + .drop(accColName) + } else { + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (predictions: Map[Int, Double]) => + predictions.maxBy(_._2)._1.toDouble + } + // output label and label metadata as prediction + aggregatedDataset + .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata) + .drop(accColName) + } } @Since("1.4.1") @@ -257,7 +289,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc) } val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models) - DefaultParamsReader.getAndSetParams(ovrModel, metadata) + metadata.getAndSetParams(ovrModel) ovrModel.set("classifier", classifier) ovrModel } @@ -297,6 +329,10 @@ final class OneVsRest 
@Since("1.4.0") ( @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + /** * The implementation of parallel one vs. rest runs the classification for * each class in a separate threads. @@ -448,7 +484,7 @@ object OneVsRest extends MLReadable[OneVsRest] { override def load(path: String): OneVsRest = { val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) val ovr = new OneVsRest(metadata.uid) - DefaultParamsReader.getAndSetParams(ovr, metadata) + metadata.getAndSetParams(ovr) ovr.setClassifier(classifier) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 78a4972adbdb..f1ef26a07d3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -310,23 +310,23 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeClassificationModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + val tree = new DecisionTreeClassificationModel(treeMetadata.uid, + root.asInstanceOf[ClassificationNode], numFeatures, numClasses) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 4c20e6563bad..438e53ba6197 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -22,24 +22,23 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} -import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import 
org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, + BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} /** * Common params for BisectingKMeans and BisectingKMeansModel */ -private[clustering] trait BisectingKMeansParams extends Params - with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { +private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter + with HasFeaturesCol with HasSeed with HasPredictionCol with HasDistanceMeasure { /** * The desired number of leaf clusters. Must be > 1. Default: 4. @@ -74,7 +73,7 @@ private[clustering] trait BisectingKMeansParams extends Params * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -104,11 +103,16 @@ class BisectingKMeansModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -127,9 +131,9 @@ class BisectingKMeansModel private[ml] ( */ @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } - parentModel.computeCost(data.map(OldVectors.fromML)) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) + parentModel.computeCost(data) } @Since("2.0.0") @@ -188,7 +192,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { val dataPath = new Path(path, "data").toString val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) val model = new BisectingKMeansModel(metadata.uid, mllibModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -248,12 +252,14 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { transformSchema(dataset.schema, logging = true) - val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) val instr = Instrumentation.create(this, rdd) instr.logParams(featuresCol, predictionCol, 
k, maxIter, seed, minDivisibleClusterSize) @@ -263,6 +269,7 @@ class BisectingKMeans @Since("2.0.0") ( .setMaxIterations($(maxIter)) .setMinDivisibleClusterSize($(minDivisibleClusterSize)) .setSeed($(seed)) + .setDistanceMeasure($(distanceMeasure)) val parentModel = bkm.run(rdd) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index f19ad7a5a693..88d618c3a03a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatr Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -63,7 +63,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT) } @@ -109,8 +109,9 @@ class GaussianMixtureModel private[ml] ( transformSchema(dataset.schema, logging = true) val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) - dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) - .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) + dataset + .withColumn($(predictionCol), predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) + .withColumn($(probabilityCol), probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -233,7 +234,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { } val model = new GaussianMixtureModel(metadata.uid, weights, gaussians) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -340,7 +341,8 @@ class GaussianMixture @Since("2.0.0") ( val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[Vector] = dataset + .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index f2af7fe082b4..97f246fbfd85 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,21 +17,23 @@ package org.apache.spark.ml.clustering +import scala.collection.mutable + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import 
org.apache.spark.ml.{Estimator, Model, PipelineStage} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion @@ -40,7 +42,7 @@ import org.apache.spark.util.VersionUtils.majorVersion * Common params for KMeans and KMeansModel */ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol - with HasSeed with HasPredictionCol with HasTol { + with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure { /** * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than @@ -90,7 +92,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with MLWritable { + private[clustering] val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with GeneralMLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -123,8 +125,11 @@ class KMeansModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) + val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("1.5.0") @@ -144,22 +149,20 @@ class KMeansModel private[ml] ( // TODO: Replace the temp fix when we have proper evaluators defined for clustering. @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) parentModel.computeCost(data) } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. 
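A hedged usage sketch of the KMeans changes in this hunk (not part of the patch; it assumes an active SparkSession named `spark` and hypothetical output paths): the new expert param `distanceMeasure` selects cosine distance, and because `write` now returns a `GeneralMLWriter`, a serialization format such as "pmml" can be chosen explicitly.

    import org.apache.spark.ml.clustering.KMeans
    import org.apache.spark.ml.linalg.Vectors

    // Toy training data; `spark` is an assumed SparkSession.
    val training = spark.createDataFrame(Seq(
      Tuple1(Vectors.dense(1.0, 0.1)), Tuple1(Vectors.dense(1.2, 0.05)),
      Tuple1(Vectors.dense(0.1, 1.0)), Tuple1(Vectors.dense(0.05, 1.2))
    )).toDF("features")

    val model = new KMeans()
      .setK(2)
      .setSeed(1L)
      .setDistanceMeasure("cosine")   // new expert param; the default remains Euclidean
      .fit(training)

    // write now returns a GeneralMLWriter, so the output format can be selected.
    model.write.format("pmml").save("/tmp/kmeans-pmml")   // hypothetical path
    model.write.save("/tmp/kmeans-internal")              // default "internal" format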
* * For [[KMeansModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * */ @Since("1.6.0") - override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) private var trainingSummary: Option[KMeansSummary] = None @@ -185,6 +188,47 @@ class KMeansModel private[ml] ( } } +/** Helper class for storing model data */ +private case class ClusterData(clusterIdx: Int, clusterCenter: Vector) + + +/** A writer for KMeans that handles the "internal" (or default) format */ +private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map { + case (center, idx) => + ClusterData(idx, center) + } + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for KMeans that handles the "pmml" format */ +private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + instance.parentModel.toPMML(sc, path) + } +} + + @Since("1.6.0") object KMeansModel extends MLReadable[KMeansModel] { @@ -194,30 +238,12 @@ object KMeansModel extends MLReadable[KMeansModel] { @Since("1.6.0") override def load(path: String): KMeansModel = super.load(path) - /** Helper class for storing model data */ - private case class Data(clusterIdx: Int, clusterCenter: Vector) - /** * We store all cluster centers in a single row and use this class to store model data by * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility. 
*/ private case class OldData(clusterCenters: Array[OldVector]) - /** [[MLWriter]] instance for [[KMeansModel]] */ - private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: cluster centers - val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => - Data(idx, center) - } - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) - } - } - private class KMeansModelReader extends MLReader[KMeansModel] { /** Checked against metadata when loading model */ @@ -232,14 +258,14 @@ object KMeansModel extends MLReadable[KMeansModel] { val dataPath = new Path(path, "data").toString val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { - val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] + val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier. sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -260,7 +286,8 @@ class KMeans @Since("1.5.0") ( maxIter -> 20, initMode -> MLlibKMeans.K_MEANS_PARALLEL, initSteps -> 2, - tol -> 1e-4) + tol -> 1e-4, + distanceMeasure -> DistanceMeasure.EUCLIDEAN) @Since("1.5.0") override def copy(extra: ParamMap): KMeans = defaultCopy(extra) @@ -284,6 +311,10 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setInitMode(value: String): this.type = set(initMode, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + /** @group expertSetParam */ @Since("1.5.0") def setInitSteps(value: Int): this.type = set(initSteps, value) @@ -305,16 +336,15 @@ class KMeans @Since("1.5.0") ( transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) if (handlePersistence) { instances.persist(StorageLevel.MEMORY_AND_DISK) } val instr = Instrumentation.create(this, instances) - instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol) + instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, + maxIter, seed, tol) val algo = new MLlibKMeans() .setK($(k)) .setInitializationMode($(initMode)) @@ -322,6 +352,7 @@ class KMeans @Since("1.5.0") ( .setMaxIterations($(maxIter)) .setSeed($(seed)) .setEpsilon($(tol)) + .setDistanceMeasure($(distanceMeasure)) val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 4bab670cc159..afe599cd167c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -43,7 +43,7 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructType} import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils @@ -345,7 +345,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM s" must be >= 1. Found value: $getTopicConcentration") } } - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } @@ -366,7 +366,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM private object LDAParams { /** - * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] + * Equivalent to [[Metadata.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+. * * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with @@ -391,7 +391,7 @@ private object LDAParams { s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } case _ => // 2.0+ - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) } } } @@ -461,7 +461,8 @@ abstract class LDAModel private[ml] ( val transformer = oldLocalModel.getTopicDistributionMethod val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() + dataset.withColumn($(topicDistributionCol), + t(DatasetUtils.columnToVector(dataset, getFeaturesCol))).toDF() } else { logWarning("LDAModel.transform was called without any output columns. Set an output column" + " such as topicDistributionCol to produce results.") @@ -938,7 +939,7 @@ object LDA extends MLReadable[LDA] { featuresCol: String): RDD[(Long, OldVector)] = { dataset .withColumn("docId", monotonically_increasing_id()) - .select("docId", featuresCol) + .select(col("docId"), DatasetUtils.columnToVector(dataset, featuresCol)) .rdd .map { case Row(docId: Long, features: Vector) => (docId, OldVectors.fromML(features)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala new file mode 100644 index 000000000000..2c30a1d9aa94 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + +/** + * Common params for PowerIterationClustering + */ +private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter + with HasPredictionCol { + + /** + * The number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + @Since("2.4.0") + final val k = new IntParam(this, "k", "The number of clusters to create. " + + "Must be > 1.", ParamValidators.gt(1)) + + /** @group getParam */ + @Since("2.4.0") + def getK: Int = $(k) + + /** + * Param for the initialization algorithm. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use a normalized sum of similarities with other vertices. + * Default: random. + * @group expertParam + */ + @Since("2.4.0") + final val initMode = { + val allowedParams = ParamValidators.inArray(Array("random", "degree")) + new Param[String](this, "initMode", "The initialization algorithm. This can be either " + + "'random' to use a random vector as vertex properties, or 'degree' to use a normalized sum " + + "of similarities with other vertices. Supported options: 'random' and 'degree'.", + allowedParams) + } + + /** @group expertGetParam */ + @Since("2.4.0") + def getInitMode: String = $(initMode) + + /** + * Param for the name of the input column for vertex IDs. + * Default: "id" + * @group param + */ + @Since("2.4.0") + val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.", + (value: String) => value.nonEmpty) + + setDefault(idCol, "id") + + /** @group getParam */ + @Since("2.4.0") + def getIdCol: String = getOrDefault(idCol) + + /** + * Param for the name of the input column for neighbors in the adjacency list representation. + * Default: "neighbors" + * @group param + */ + @Since("2.4.0") + val neighborsCol = new Param[String](this, "neighborsCol", + "Name of the input column for neighbors in the adjacency list representation.", + (value: String) => value.nonEmpty) + + setDefault(neighborsCol, "neighbors") + + /** @group getParam */ + @Since("2.4.0") + def getNeighborsCol: String = $(neighborsCol) + + /** + * Param for the name of the input column for neighbors in the adjacency list representation. 
+ * Default: "similarities" + * @group param + */ + @Since("2.4.0") + val similaritiesCol = new Param[String](this, "similaritiesCol", + "Name of the input column for neighbors in the adjacency list representation.", + (value: String) => value.nonEmpty) + + setDefault(similaritiesCol, "similarities") + + /** @group getParam */ + @Since("2.4.0") + def getSimilaritiesCol: String = $(similaritiesCol) + + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType)) + SchemaUtils.checkColumnTypes(schema, $(neighborsCol), + Seq(ArrayType(IntegerType, containsNull = false), + ArrayType(LongType, containsNull = false))) + SchemaUtils.checkColumnTypes(schema, $(similaritiesCol), + Seq(ArrayType(FloatType, containsNull = false), + ArrayType(DoubleType, containsNull = false))) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + } +} + +/** + * :: Experimental :: + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + * Lin and Cohen. From the abstract: + * PIC finds a very low-dimensional embedding of a dataset using truncated power + * iteration on a normalized pair-wise similarity matrix of the data. + * + * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix + * is a symmetric matrix whose entries are non-negative similarities between items. + * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes: + * - `idCol`: vertex ID + * - `neighborsCol`: neighbors of vertex in `idCol` + * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex + * in `idCol` and each neighbor in `neighborsCol` + * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol` + * containing the cluster assignment in `[0,k)` for each row (vertex). + * + * Notes: + * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation. + * Transform runs the iterative PIC algorithm to cluster the whole input dataset. + * - Input validation: This validates that similarities are non-negative but does NOT validate + * that the input matrix is symmetric. 
+ * + * @see + * Spectral clustering (Wikipedia) + */ +@Since("2.4.0") +@Experimental +class PowerIterationClustering private[clustering] ( + @Since("2.4.0") override val uid: String) + extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { + + setDefault( + k -> 2, + maxIter -> 20, + initMode -> "random") + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("PowerIterationClustering")) + + /** @group setParam */ + @Since("2.4.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setK(value: Int): this.type = set(k, value) + + /** @group expertSetParam */ + @Since("2.4.0") + def setInitMode(value: String): this.type = set(initMode, value) + + /** @group setParam */ + @Since("2.4.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.4.0") + def setIdCol(value: String): this.type = set(idCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setNeighborsCol(value: String): this.type = set(neighborsCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value) + + @Since("2.4.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + + val sparkSession = dataset.sparkSession + val idColValue = $(idCol) + val rdd: RDD[(Long, Long, Double)] = + dataset.select( + col($(idCol)).cast(LongType), + col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)), + col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false)) + ).rdd.flatMap { + case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) => + require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " + + s"equal to the the length of the neighbor similarity list. 
Row for ID " + + s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " + + s"of length ${sims.length}.") + nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map { + case (nbr, similarity) => (id, nbr, similarity) + } + } + val algorithm = new MLlibPowerIterationClustering() + .setK($(k)) + .setInitializationMode($(initMode)) + .setMaxIterations($(maxIter)) + val model = algorithm.run(rdd) + + val predictionsRDD: RDD[Row] = model.assignments.map { assignment => + Row(assignment.id, assignment.cluster) + } + + val predictionsSchema = StructType(Seq( + StructField($(idCol), LongType, nullable = false), + StructField($(predictionCol), IntegerType, nullable = false))) + val predictions = { + val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) + dataset.schema($(idCol)).dataType match { + case _: LongType => + uncastPredictions + case otherType => + uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) + } + } + + dataset.join(predictions, $(idCol)) + } + + @Since("2.4.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + @Since("2.4.0") + override def copy(extra: ParamMap): PowerIterationClustering = defaultCopy(extra) +} + +@Since("2.4.0") +object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClustering] { + + @Since("2.4.0") + override def load(path: String): PowerIterationClustering = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index d6ec5223237b..4353c46781e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -20,11 +20,13 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, + SchemaUtils} +import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions.{avg, col, udf} import org.apache.spark.sql.types.DoubleType @@ -32,15 +34,11 @@ import org.apache.spark.sql.types.DoubleType * :: Experimental :: * * Evaluator for clustering results. - * The metric computes the Silhouette measure - * using the squared Euclidean distance. - * - * The Silhouette is a measure for the validation - * of the consistency within clusters. It ranges - * between 1 and -1, where a value close to 1 - * means that the points in a cluster are close - * to the other points in the same cluster and - * far from the points of the other clusters. + * The metric computes the Silhouette measure using the specified distance measure. 
+ * + * The Silhouette is a measure for the validation of the consistency within clusters. It ranges + * between 1 and -1, where a value close to 1 means that the points in a cluster are close to the + * other points in the same cluster and far from the points of the other clusters. */ @Experimental @Since("2.3.0") @@ -84,18 +82,40 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") def setMetricName(value: String): this.type = set(metricName, value) - setDefault(metricName -> "silhouette") + /** + * param for distance measure to be used in evaluation + * (supports `"squaredEuclidean"` (default), `"cosine"`) + * @group param + */ + @Since("2.4.0") + val distanceMeasure: Param[String] = { + val availableValues = Array("squaredEuclidean", "cosine") + val allowedParams = ParamValidators.inArray(availableValues) + new Param(this, "distanceMeasure", "distance measure in evaluation. Supported options: " + + availableValues.mkString("'", "', '", "'"), allowedParams) + } + + /** @group getParam */ + @Since("2.4.0") + def getDistanceMeasure: String = $(distanceMeasure) + + /** @group setParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + + setDefault(metricName -> "silhouette", distanceMeasure -> "squaredEuclidean") @Since("2.3.0") override def evaluate(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) SchemaUtils.checkNumericType(dataset.schema, $(predictionCol)) - $(metricName) match { - case "silhouette" => + ($(metricName), $(distanceMeasure)) match { + case ("silhouette", "squaredEuclidean") => SquaredEuclideanSilhouette.computeSilhouetteScore( - dataset, $(predictionCol), $(featuresCol) - ) + dataset, $(predictionCol), $(featuresCol)) + case ("silhouette", "cosine") => + CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol)) } } } @@ -111,6 +131,57 @@ object ClusteringEvaluator } +private[evaluation] abstract class Silhouette { + + /** + * It computes the Silhouette coefficient for a point. + */ + def pointSilhouetteCoefficient( + clusterIds: Set[Double], + pointClusterId: Double, + pointClusterNumOfPoints: Long, + averageDistanceToCluster: (Double) => Double): Double = { + // Here we compute the average dissimilarity of the current point to any cluster of which the + // point is not a member. + // The cluster with the lowest average dissimilarity - i.e. the nearest cluster to the current + // point - is said to be the "neighboring cluster". + val otherClusterIds = clusterIds.filter(_ != pointClusterId) + val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min + + // adjustment for excluding the node itself from the computation of the average dissimilarity + val currentClusterDissimilarity = if (pointClusterNumOfPoints == 1) { + 0.0 + } else { + averageDistanceToCluster(pointClusterId) * pointClusterNumOfPoints / + (pointClusterNumOfPoints - 1) + } + + if (currentClusterDissimilarity < neighboringClusterDissimilarity) { + 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) + } else if (currentClusterDissimilarity > neighboringClusterDissimilarity) { + (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 + } else { + 0.0 + } + } + + /** + * Compute the mean Silhouette values of all samples. 
+ */ + def overallScore(df: DataFrame, scoreColumn: Column): Double = { + df.select(avg(scoreColumn)).collect()(0).getDouble(0) + } + + protected def getNumberOfFeatures(dataFrame: DataFrame, columnName: String): Int = { + val group = AttributeGroup.fromStructField(dataFrame.schema(columnName)) + if (group.size < 0) { + dataFrame.select(col(columnName)).first().getAs[Vector](0).size + } else { + group.size + } + } +} + /** * SquaredEuclideanSilhouette computes the average of the * Silhouette over all the data of the dataset, which is @@ -259,7 +330,7 @@ object ClusteringEvaluator * `N` is the number of points in the dataset and `W` is the number * of worker nodes. */ -private[evaluation] object SquaredEuclideanSilhouette { +private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { private[this] var kryoRegistrationPerformed: Boolean = false @@ -299,7 +370,7 @@ private[evaluation] object SquaredEuclideanSilhouette { df: DataFrame, predictionCol: String, featuresCol: String): Map[Double, ClusterStats] = { - val numFeatures = df.select(col(featuresCol)).first().getAs[Vector](0).size + val numFeatures = getNumberOfFeatures(df, featuresCol) val clustersStatsRDD = df.select( col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm")) .rdd @@ -336,18 +407,19 @@ private[evaluation] object SquaredEuclideanSilhouette { * It computes the Silhouette coefficient for a point. * * @param broadcastedClustersMap A map of the precomputed values for each cluster. - * @param features The [[org.apache.spark.ml.linalg.Vector]] representing the current point. + * @param point The [[org.apache.spark.ml.linalg.Vector]] representing the current point. * @param clusterId The id of the cluster the current point belongs to. * @param squaredNorm The `$\Xi_{X}$` (which is the squared norm) precomputed for the point. * @return The Silhouette for the point. */ def computeSilhouetteCoefficient( broadcastedClustersMap: Broadcast[Map[Double, ClusterStats]], - features: Vector, + point: Vector, clusterId: Double, squaredNorm: Double): Double = { - def compute(squaredNorm: Double, point: Vector, clusterStats: ClusterStats): Double = { + def compute(targetClusterId: Double): Double = { + val clusterStats = broadcastedClustersMap.value(targetClusterId) val pointDotClusterFeaturesSum = BLAS.dot(point, clusterStats.featureSum) squaredNorm + @@ -355,41 +427,14 @@ private[evaluation] object SquaredEuclideanSilhouette { 2 * pointDotClusterFeaturesSum / clusterStats.numOfPoints } - // Here we compute the average dissimilarity of the - // current point to any cluster of which the point - // is not a member. - // The cluster with the lowest average dissimilarity - // - i.e. the nearest cluster to the current point - - // is said to be the "neighboring cluster". 
- var neighboringClusterDissimilarity = Double.MaxValue - broadcastedClustersMap.value.keySet.foreach { - c => - if (c != clusterId) { - val dissimilarity = compute(squaredNorm, features, broadcastedClustersMap.value(c)) - if(dissimilarity < neighboringClusterDissimilarity) { - neighboringClusterDissimilarity = dissimilarity - } - } - } - val currentCluster = broadcastedClustersMap.value(clusterId) - // adjustment for excluding the node itself from - // the computation of the average dissimilarity - val currentClusterDissimilarity = if (currentCluster.numOfPoints == 1) { - 0 - } else { - compute(squaredNorm, features, currentCluster) * currentCluster.numOfPoints / - (currentCluster.numOfPoints - 1) - } - - (currentClusterDissimilarity compare neighboringClusterDissimilarity).signum match { - case -1 => 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) - case 1 => (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 - case 0 => 0.0 - } + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId).numOfPoints, + compute) } /** - * Compute the mean Silhouette values of all samples. + * Compute the Silhouette score of the dataset using squared Euclidean distance measure. * * @param dataset The input dataset (previously clustered) on which compute the Silhouette. * @param predictionCol The name of the column which contains the predicted cluster id @@ -412,7 +457,7 @@ private[evaluation] object SquaredEuclideanSilhouette { val clustersStatsMap = SquaredEuclideanSilhouette .computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol) - // Silhouette is reasonable only when the number of clusters is grater then 1 + // Silhouette is reasonable only when the number of clusters is greater then 1 assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) @@ -421,13 +466,194 @@ private[evaluation] object SquaredEuclideanSilhouette { computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double) } - val silhouetteScore = dfWithSquaredNorm - .select(avg( - computeSilhouetteCoefficientUDF( - col(featuresCol), col(predictionCol).cast(DoubleType), col("squaredNorm")) - )) - .collect()(0) - .getDouble(0) + val silhouetteScore = overallScore(dfWithSquaredNorm, + computeSilhouetteCoefficientUDF(col(featuresCol), col(predictionCol).cast(DoubleType), + col("squaredNorm"))) + + bClustersStatsMap.destroy() + + silhouetteScore + } +} + + +/** + * The algorithm which is implemented in this object, instead, is an efficient and parallel + * implementation of the Silhouette using the cosine distance measure. The cosine distance + * measure is defined as `1 - s` where `s` is the cosine similarity between two points. + * + * The total distance of the point `X` to the points `$C_{i}$` belonging to the cluster `$\Gamma$` + * is: + * + *
      + * $$ + * \sum\limits_{i=1}^N d(X, C_{i} ) = + * \sum\limits_{i=1}^N \Big( 1 - \frac{\sum\limits_{j=1}^D x_{j}c_{ij} }{ \|X\|\|C_{i}\|} \Big) + * = \sum\limits_{i=1}^N 1 - \sum\limits_{i=1}^N \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} + * \frac{c_{ij}}{\|C_{i}\|} + * = N - \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} \Big( \sum\limits_{i=1}^N + * \frac{c_{ij}}{\|C_{i}\|} \Big) + * $$ + *
      + * + * where `$x_{j}$` is the `j`-th dimension of the point `X` and `$c_{ij}$` is the `j`-th dimension + * of the `i`-th point in cluster `$\Gamma$`. + * + * Then, we can define the vector: + * + *
      + * $$ + * \xi_{X} : \xi_{X i} = \frac{x_{i}}{\|X\|}, i = 1, ..., D + * $$ + *
      + * + * which can be precomputed for each point and the vector + * + *
      + * $$ + * \Omega_{\Gamma} : \Omega_{\Gamma i} = \sum\limits_{j=1}^N \xi_{C_{j}i}, i = 1, ..., D + * $$ + *
      + * + * which can be precomputed too for each cluster `$\Gamma$` by its points `$C_{i}$`. + * + * With these definitions, the numerator becomes: + * + *
      + * $$ + * N - \sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j} + * $$ + *
      + * + * Thus the average distance of a point `X` to the points of the cluster `$\Gamma$` is: + * + *
      + * $$ + * 1 - \frac{\sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j}}{N} + * $$ + *
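A small self-contained check of the identity above (not part of the patch; it uses plain arrays rather than ml.linalg vectors, since the BLAS helpers are package-private): the mean cosine distance from a point to a toy cluster matches `1 - (xi . omega) / N`.

    // Verifies: avg_i (1 - cos(x, c_i)) == 1 - dot(x/||x||, sum_i c_i/||c_i||) / N
    def norm(v: Array[Double]): Double = math.sqrt(v.map(d => d * d).sum)
    def dot(a: Array[Double], b: Array[Double]): Double =
      a.zip(b).map { case (x, y) => x * y }.sum
    def normalize(v: Array[Double]): Array[Double] = { val n = norm(v); v.map(_ / n) }

    val x = Array(1.0, 2.0)
    val cluster = Seq(Array(3.0, 0.5), Array(0.0, 1.0))

    // Direct form: mean of the pairwise cosine distances.
    val direct = cluster.map(c => 1.0 - dot(x, c) / (norm(x) * norm(c))).sum / cluster.size

    // Precomputed form used by CosineSilhouette: xi = x/||x||, omega = sum of normalized points.
    val xi = normalize(x)
    val omega = cluster.map(normalize).reduce((a, b) => a.zip(b).map { case (p, q) => p + q })
    val precomputed = 1.0 - dot(xi, omega) / cluster.size

    assert(math.abs(direct - precomputed) < 1e-12)   // the two forms agree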
      + * + * In the implementation, the precomputed values for the clusters are distributed among the worker + * nodes via broadcasted variables, because we can assume that the clusters are limited in number. + * + * The main strengths of this algorithm are the low computational complexity and the intrinsic + * parallelism. The precomputed information for each point and for each cluster can be computed + * with a computational complexity which is `O(N/W)`, where `N` is the number of points in the + * dataset and `W` is the number of worker nodes. After that, every point can be analyzed + * independently from the others. + * + * For every point we need to compute the average distance to all the clusters. Since the formula + * above requires `O(D)` operations, this phase has a computational complexity which is + * `O(C*D*N/W)` where `C` is the number of clusters (which we assume quite low), `D` is the number + * of dimensions, `N` is the number of points in the dataset and `W` is the number of worker + * nodes. + */ +private[evaluation] object CosineSilhouette extends Silhouette { + + private[this] val normalizedFeaturesColName = "normalizedFeatures" + + /** + * The method takes the input dataset and computes the aggregated values + * about a cluster which are needed by the algorithm. + * + * @param df The DataFrame which contains the input data + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a + * its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`). + */ + def computeClusterStats( + df: DataFrame, + featuresCol: String, + predictionCol: String): Map[Double, (Vector, Long)] = { + val numFeatures = getNumberOfFeatures(df, featuresCol) + val clustersStatsRDD = df.select( + col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName)) + .rdd + .map { row => (row.getDouble(0), row.getAs[Vector](1)) } + .aggregateByKey[(DenseVector, Long)]((Vectors.zeros(numFeatures).toDense, 0L))( + seqOp = { + case ((normalizedFeaturesSum: DenseVector, numOfPoints: Long), (normalizedFeatures)) => + BLAS.axpy(1.0, normalizedFeatures, normalizedFeaturesSum) + (normalizedFeaturesSum, numOfPoints + 1) + }, + combOp = { + case ((normalizedFeaturesSum1, numOfPoints1), (normalizedFeaturesSum2, numOfPoints2)) => + BLAS.axpy(1.0, normalizedFeaturesSum2, normalizedFeaturesSum1) + (normalizedFeaturesSum1, numOfPoints1 + numOfPoints2) + } + ) + + clustersStatsRDD + .collectAsMap() + .toMap + } + + /** + * It computes the Silhouette coefficient for a point. + * + * @param broadcastedClustersMap A map of the precomputed values for each cluster. + * @param normalizedFeatures The [[org.apache.spark.ml.linalg.Vector]] representing the + * normalized features of the current point. + * @param clusterId The id of the cluster the current point belongs to. 
+ */ + def computeSilhouetteCoefficient( + broadcastedClustersMap: Broadcast[Map[Double, (Vector, Long)]], + normalizedFeatures: Vector, + clusterId: Double): Double = { + + def compute(targetClusterId: Double): Double = { + val (normalizedFeatureSum, numOfPoints) = broadcastedClustersMap.value(targetClusterId) + 1 - BLAS.dot(normalizedFeatures, normalizedFeatureSum) / numOfPoints + } + + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId)._2, + compute) + } + + /** + * Compute the Silhouette score of the dataset using the cosine distance measure. + * + * @param dataset The input dataset (previously clustered) on which compute the Silhouette. + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @param featuresCol The name of the column which contains the feature vector of the point. + * @return The average of the Silhouette values of the clustered data. + */ + def computeSilhouetteScore( + dataset: Dataset[_], + predictionCol: String, + featuresCol: String): Double = { + val normalizeFeatureUDF = udf { + features: Vector => { + val norm = Vectors.norm(features, 2.0) + features match { + case d: DenseVector => Vectors.dense(d.values.map(_ / norm)) + case s: SparseVector => Vectors.sparse(s.size, s.indices, s.values.map(_ / norm)) + } + } + } + val dfWithNormalizedFeatures = dataset.withColumn(normalizedFeaturesColName, + normalizeFeatureUDF(col(featuresCol))) + + // compute aggregate values for clusters needed by the algorithm + val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol, + predictionCol) + + // Silhouette is reasonable only when the number of clusters is greater then 1 + assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") + + val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) + + val computeSilhouetteCoefficientUDF = udf { + computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double) + } + + val silhouetteScore = overallScore(dfWithNormalizedFeatures, + computeSilhouetteCoefficientUDF(col(normalizedFeaturesColName), + col(predictionCol).cast(DoubleType))) bClustersStatsMap.destroy() diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 36a46ca6ff4b..a906e954fecd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -73,6 +73,14 @@ class BucketedRandomProjectionLSHModel private[ml]( private[ml] val randUnitVectors: Array[Vector]) extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams { + /** @group setParam */ + @Since("2.4.0") + override def setInputCol(value: String): this.type = super.set(inputCol, value) + + /** @group setParam */ + @Since("2.4.0") + override def setOutputCol(value: String): this.type = super.set(outputCol, value) + @Since("2.1.0") override protected[ml] val hashFunction: Vector => Array[Vector] = { key: Vector => { @@ -230,7 +238,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject val model = new BucketedRandomProjectionLSHModel(metadata.uid, randUnitVectors.rowIter.toArray) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 8299a3e95d82..f99649f7fa16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -19,6 +19,10 @@ package org.apache.spark.ml.feature import java.{util => ju} +import org.json4s.JsonDSL._ +import org.json4s.JValue +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.Model @@ -32,11 +36,13 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0, + * `Bucketizer` maps a column of continuous features to a column of feature buckets. + * + * Since 2.3.0, * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that - * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and - * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is - * only used for single column usage, and `splitsArray` is for multiple columns. + * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The + * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple + * columns. */ @Since("1.4.0") final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) @@ -134,28 +140,11 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("2.3.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) - /** - * Determines whether this `Bucketizer` is going to map multiple columns. If and only if - * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified - * by `inputCol`. A warning will be printed if both are set. 
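For reference, a hedged sketch of the multi-column usage this stricter validation targets (it assumes a hypothetical `people` DataFrame with numeric "height" and "weight" columns):

    import org.apache.spark.ml.feature.Bucketizer

    // inputCols, outputCols and splitsArray must have matching lengths, and mixing
    // them with the single-column params now raises an exception at transformSchema.
    val bucketizer = new Bucketizer()
      .setInputCols(Array("height", "weight"))
      .setOutputCols(Array("heightBucket", "weightBucket"))
      .setSplitsArray(Array(
        Array(Double.NegativeInfinity, 150.0, 180.0, Double.PositiveInfinity),
        Array(Double.NegativeInfinity, 60.0, 90.0, Double.PositiveInfinity)))

    val bucketed = bucketizer.transform(people)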
- */ - private[feature] def isBucketizeMultipleColumns(): Boolean = { - if (isSet(inputCols) && isSet(inputCol)) { - logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + - "`Bucketizer` only map one column specified by `inputCol`") - false - } else if (isSet(inputCols)) { - true - } else { - false - } - } - @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema) - val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { + val (inputColumns, outputColumns) = if (isSet(inputCols)) { ($(inputCols).toSeq, $(outputCols).toSeq) } else { (Seq($(inputCol)), Seq($(outputCol))) @@ -170,7 +159,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String } } - val seqOfSplits = if (isBucketizeMultipleColumns()) { + val seqOfSplits = if (isSet(inputCols)) { $(splitsArray).toSeq } else { Seq($(splits)) @@ -201,9 +190,18 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - if (isBucketizeMultipleColumns()) { + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits), + Seq(outputCols, splitsArray)) + + if (isSet(inputCols)) { + require(getInputCols.length == getOutputCols.length && + getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " + + s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " + + s"equal lengths, but they have different lengths: " + + s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).") + var transformedSchema = schema - $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) => + $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) => SchemaUtils.checkNumericType(transformedSchema, inputCol) transformedSchema = SchemaUtils.appendColumn(transformedSchema, prepOutputField($(splitsArray)(idx), outputCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 16abc4949dea..dbfb199ccd58 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -334,7 +334,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { val selectedFeatures = data.getAs[Seq[Int]](0).toArray val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) val model = new ChiSqSelectorModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 1ebe29703bc4..10c48c3f5208 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -69,6 +69,27 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** @group getParam */ def getMinDF: Double = $(minDF) + /** + * Specifies the maximum number of different documents a term could appear in to be included + * in the vocabulary. A term that appears more than the threshold will be ignored. 
If this is an + * integer greater than or equal to 1, this specifies the maximum number of documents the term + * could appear in; if this is a double in [0,1), then this specifies the maximum fraction of + * documents the term could appear in. + * + * Default: (2^63^) - 1 + * @group param + */ + val maxDF: DoubleParam = new DoubleParam(this, "maxDF", "Specifies the maximum number of" + + " different documents a term could appear in to be included in the vocabulary." + + " A term that appears more than the threshold will be ignored. If this is an integer >= 1," + + " this specifies the maximum number of documents the term could appear in;" + + " if this is a double in [0,1), then this specifies the maximum fraction of" + + " documents the term could appear in.", + ParamValidators.gtEq(0.0)) + + /** @group getParam */ + def getMaxDF: Double = $(maxDF) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) @@ -113,7 +134,11 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** @group getParam */ def getBinary: Boolean = $(binary) - setDefault(vocabSize -> (1 << 18), minDF -> 1.0, minTF -> 1.0, binary -> false) + setDefault(vocabSize -> (1 << 18), + minDF -> 1.0, + maxDF -> Long.MaxValue, + minTF -> 1.0, + binary -> false) } /** @@ -142,6 +167,10 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") def setMinDF(value: Double): this.type = set(minDF, value) + /** @group setParam */ + @Since("2.4.0") + def setMaxDF(value: Double): this.type = set(maxDF, value) + /** @group setParam */ @Since("1.5.0") def setMinTF(value: Double): this.type = set(minTF, value) @@ -155,12 +184,24 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) + val countingRequired = $(minDF) < 1.0 || $(maxDF) < 1.0 + val maybeInputSize = if (countingRequired) { + Some(input.cache().count()) + } else { + None + } val minDf = if ($(minDF) >= 1.0) { $(minDF) } else { - $(minDF) * input.cache().count() + $(minDF) * maybeInputSize.get + } + val maxDf = if ($(maxDF) >= 1.0) { + $(maxDF) + } else { + $(maxDF) * maybeInputSize.get } - val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) => + require(maxDf >= minDf, "maxDF must be >= minDF.") + val allWordCounts = input.flatMap { case (tokens) => val wc = new OpenHashMap[String, Long] tokens.foreach { w => wc.changeValue(w, 1L, _ + 1L) @@ -168,11 +209,23 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) wc.map { case (word, count) => (word, (count, 1)) } }.reduceByKey { case ((wc1, df1), (wc2, df2)) => (wc1 + wc2, df1 + df2) - }.filter { case (word, (wc, df)) => - df >= minDf - }.map { case (word, (count, dfCount)) => - (word, count) - }.cache() + } + + val filteringRequired = isSet(minDF) || isSet(maxDF) + val maybeFilteredWordCounts = if (filteringRequired) { + allWordCounts.filter { case (_, (_, df)) => df >= minDf && df <= maxDf } + } else { + allWordCounts + } + + val wordCounts = maybeFilteredWordCounts + .map { case (word, (count, _)) => (word, count) } + .cache() + + if (countingRequired) { + input.unpersist() + } + val fullVocabSize = wordCounts.count() val vocab = wordCounts @@ -310,7 +363,7 @@ object 
CountVectorizerModel extends MLReadable[CountVectorizerModel] { .head() val vocabulary = data.getAs[Seq[String]](0).toArray val model = new CountVectorizerModel(metadata.uid, vocabulary) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index a918dd4c075d..d67e4819b161 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup @@ -28,6 +29,8 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2Block} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashMap @@ -138,7 +141,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme @Since("2.3.0") override def transform(dataset: Dataset[_]): DataFrame = { - val hashFunc: Any => Int = OldHashingTF.murmur3Hash + val hashFunc: Any => Int = FeatureHasher.murmur3Hash val n = $(numFeatures) val localInputCols = $(inputCols) val catCols = if (isSet(categoricalCols)) { @@ -218,4 +221,31 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] { @Since("2.3.0") override def load(path: String): FeatureHasher = super.load(path) + + private val seed = OldHashingTF.seed + + /** + * Calculate a hash code value for the term object using + * Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32). + * This is the default hash algorithm used from Spark 2.0 onwards. + * Use hashUnsafeBytes2 to match the original algorithm with the value. + * See SPARK-23381. 
+ */ + @Since("2.3.0") + private[feature] def murmur3Hash(term: Any): Int = { + term match { + case null => seed + case b: Boolean => hashInt(if (b) 1 else 0, seed) + case b: Byte => hashInt(b, seed) + case s: Short => hashInt(s, seed) + case i: Int => hashInt(i, seed) + case l: Long => hashLong(l, seed) + case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) + case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) + case s: String => + hashUnsafeBytes2Block(UTF8String.fromString(s).getMemoryBlock, seed) + case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " + + s"support type ${term.getClass.getCanonicalName} of input data.") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 46a0730f5ddb..58897cca4e5c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -182,7 +182,7 @@ object IDFModel extends MLReadable[IDFModel] { .select("idf") .head() val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf))) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 730ee9fc08db..1c074e204ad9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -262,7 +262,7 @@ object ImputerModel extends MLReadable[ImputerModel] { val dataPath = new Path(path, "data").toString val surrogateDF = sqlContext.read.parquet(dataPath) val model = new ImputerModel(metadata.uid, surrogateDF) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index 1c9f47a0b201..a70931f783f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -65,6 +65,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams with MLWritable { self: T => + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + /** * The hash function of LSH, mapping an input feature vector to multiple hash vectors. * @return The mapping of LSH function. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 85f9732f79f6..90eceb0d61b4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -172,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { .select("maxAbs") .head() val model = new MaxAbsScalerModel(metadata.uid, maxAbs) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 145422a05919..a67a3b0abbc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -51,6 +51,14 @@ class MinHashLSHModel private[ml]( private[ml] val randCoefficients: Array[(Int, Int)]) extends LSHModel[MinHashLSHModel] { + /** @group setParam */ + @Since("2.4.0") + override def setInputCol(value: String): this.type = super.set(inputCol, value) + + /** @group setParam */ + @Since("2.4.0") + override def setOutputCol(value: String): this.type = super.set(outputCol, value) + @Since("2.1.0") override protected[ml] val hashFunction: Vector => Array[Vector] = { elems: Vector => { @@ -197,7 +205,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { .map(tuple => (tuple(0), tuple(1))).toArray val model = new MinHashLSHModel(metadata.uid, randCoefficients) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index f648deced54c..2e0ae4af66f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -243,7 +243,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { .select("originalMin", "originalMax") .head() val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 074622d41e28..4a44f3186538 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -30,24 +30,27 @@ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions.{col, lit, udf} -import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid with HasInputCols with HasOutputCols { /** - * Param for how to handle invalid data. + * Param for how to handle invalid data during transform(). * Options are 'keep' (invalid data presented as an extra categorical feature) or * 'error' (throw an error). 
+ * Note that this Param is only used during transform; during fitting, invalid data + * will result in an error. * Default: "error" * @group param */ @Since("2.3.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - "How to handle invalid data " + + "How to handle invalid data during transform(). " + "Options are 'keep' (invalid data presented as an extra categorical feature) " + - "or error (throw an error).", + "or error (throw an error). Note that this Param is only used during transform; " + + "during fitting, invalid data will result in an error.", ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) @@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid def getDropLast: Boolean = $(dropLast) protected def validateAndTransformSchema( - schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = { + schema: StructType, + dropLast: Boolean, + keepInvalid: Boolean): StructType = { val inputColNames = $(inputCols) val outputColNames = $(outputCols) - val existingFields = schema.fields require(inputColNames.length == outputColNames.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat override def load(path: String): OneHotEncoderEstimator = super.load(path) } +/** + * @param categorySizes Original number of categories for each feature being encoded. + * The array contains one value for each input column, in order. + */ @Since("2.3.0") class OneHotEncoderModel private[ml] ( @Since("2.3.0") override val uid: String, @@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] ( import OneHotEncoderModel._ - // Returns the category size for a given index with `dropLast` and `handleInvalid` + // Returns the category size for each index with `dropLast` and `handleInvalid` // taken into account. - private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = { + private def getConfigedCategorySizes: Array[Int] = { val dropLast = getDropLast val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID if (!dropLast && keepInvalid) { // When `handleInvalid` is "keep", an extra category is added as last category // for invalid data. - orgCategorySize + 1 + categorySizes.map(_ + 1) } else if (dropLast && !keepInvalid) { // When `dropLast` is true, the last category is removed. - orgCategorySize - 1 + categorySizes.map(_ - 1) } else { // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid // data is removed. Thus, it is the same as the plain number of categories. - orgCategorySize + categorySizes } } private def encoder: UserDefinedFunction = { - val oneValue = Array(1.0) - val emptyValues = Array.empty[Double] - val emptyIndices = Array.empty[Int] - val dropLast = getDropLast - val handleInvalid = getHandleInvalid - val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val configedSizes = getConfigedCategorySizes + val localCategorySizes = categorySizes // The udf performed on input data. The first parameter is the input value. The second - // parameter is the index of input. 
- udf { (label: Double, idx: Int) => - val plainNumCategories = categorySizes(idx) - val size = configedCategorySize(plainNumCategories, idx) - - if (label < 0) { - throw new SparkException(s"Negative value: $label. Input can't be negative.") - } else if (label == size && dropLast && !keepInvalid) { - // When `dropLast` is true and `handleInvalid` is not "keep", - // the last category is removed. - Vectors.sparse(size, emptyIndices, emptyValues) - } else if (label >= plainNumCategories && keepInvalid) { - // When `handleInvalid` is "keep", encodes invalid data to last category (and removed - // if `dropLast` is true) - if (dropLast) { - Vectors.sparse(size, emptyIndices, emptyValues) + // parameter is the index in inputCols of the column being encoded. + udf { (label: Double, colIdx: Int) => + val origCategorySize = localCategorySizes(colIdx) + // idx: index in vector of the single 1-valued element + val idx = if (label >= 0 && label < origCategorySize) { + label + } else { + if (keepInvalid) { + origCategorySize } else { - Vectors.sparse(size, Array(size - 1), oneValue) + if (label < 0) { + throw new SparkException(s"Negative value: $label. Input can't be negative. " + + s"To handle invalid values, set Param handleInvalid to " + + s"${OneHotEncoderEstimator.KEEP_INVALID}") + } else { + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + } } - } else if (label < plainNumCategories) { - Vectors.sparse(size, Array(label.toInt), oneValue) + } + + val size = configedSizes(colIdx) + if (idx < size) { + Vectors.sparse(size, Array(idx.toInt), Array(1.0)) } else { - assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) - throw new SparkException(s"Unseen value: $label. To handle unseen values, " + - s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + Vectors.sparse(size, Array.empty[Int], Array.empty[Double]) } } } @@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] ( @Since("2.3.0") override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) - val outputColNames = $(outputCols) require(inputColNames.length == categorySizes.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] ( * account. Mismatched numbers will cause exception. */ private def verifyNumOfValues(schema: StructType): StructType = { + val configedSizes = getConfigedCategorySizes $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => val inputColName = $(inputCols)(idx) val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) @@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] ( // comparing with expected category number with `handleInvalid` and // `dropLast` taken into account. 
if (attrGroup.attributes.nonEmpty) { - val numCategories = configedCategorySize(categorySizes(idx), idx) + val numCategories = configedSizes(idx) require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + - s"$numCategories categorical values for input column ${inputColName}, " + + s"$numCategories categorical values for input column $inputColName, " + s"but the input column had metadata specifying ${attrGroup.size} values.") } } @@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] ( val transformedSchema = transformSchema(dataset.schema, logging = true) val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID - val encodedColumns = (0 until $(inputCols).length).map { idx => + val encodedColumns = $(inputCols).indices.map { idx => val inputColName = $(inputCols)(idx) val outputColName = $(outputCols)(idx) @@ -380,7 +386,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { .head() val categorySizes = data.getAs[Seq[Int]](0).toArray val model = new OneHotEncoderModel(metadata.uid, categorySizes) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 4143d864d793..8172491a517d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -220,7 +220,7 @@ object PCAModel extends MLReadable[PCAModel] { new PCAModel(metadata.uid, pc.asML, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 1ec5f8cb6139..56e2c543d100 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -17,6 +17,10 @@ package org.apache.spark.ml.feature +import org.json4s.JsonDSL._ +import org.json4s.JValue +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 7da3339f8b48..55e595eee6ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol} import org.apache.spark.ml.util._ @@ -74,7 +74,7 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with * @group param */ @Since("2.3.0") - final override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + override val handleInvalid: Param[String] = new Param[String](this, 
"handleInvalid", "How to handle invalid data (unseen or NULL values) in features and label column of string " + "type. Options are 'skip' (filter out rows with invalid data), error (throw an error), " + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", @@ -199,6 +199,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) val encoderStages = ArrayBuffer[PipelineStage]() + val oneHotEncodeColumns = ArrayBuffer[(String, String)]() val prefixesToRewrite = mutable.Map[String, String]() val tempColumns = ArrayBuffer[String]() @@ -210,8 +211,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) // First we index each string column referenced by the input terms. val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term => - dataset.schema(term) match { - case column if column.dataType == StringType => + dataset.schema(term).dataType match { + case _: StringType => val indexCol = tmpColumn("stridx") encoderStages += new StringIndexer() .setInputCol(term) @@ -220,6 +221,18 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) .setHandleInvalid($(handleInvalid)) prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(dataset.schema(term)) + val size = if (group.size < 0) { + dataset.select(term).first().getAs[Vector](0).size + } else { + group.size + } + encoderStages += new VectorSizeHint(uid) + .setHandleInvalid("optimistic") + .setInputCol(term) + .setSize(size) + (term, term) case _ => (term, term) } @@ -230,16 +243,17 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) val encodedTerms = resolvedFormula.terms.map { case Seq(term) if dataset.schema(term).dataType == StringType => val encodedCol = tmpColumn("onehot") - var encoder = new OneHotEncoder() - .setInputCol(indexed(term)) - .setOutputCol(encodedCol) // Formula w/o intercept, one of the categories in the first category feature is // being used as reference category, we will not drop any category for that feature. 
if (!hasIntercept && !keepReferenceCategory) { - encoder = encoder.setDropLast(false) + encoderStages += new OneHotEncoderEstimator(uid) + .setInputCols(Array(indexed(term))) + .setOutputCols(Array(encodedCol)) + .setDropLast(false) keepReferenceCategory = true + } else { + oneHotEncodeColumns += indexed(term) -> encodedCol } - encoderStages += encoder prefixesToRewrite(encodedCol + "_") = term + "_" encodedCol case Seq(term) => @@ -253,9 +267,18 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) interactionCol } + if (oneHotEncodeColumns.nonEmpty) { + val (inputCols, outputCols) = oneHotEncodeColumns.toArray.unzip + encoderStages += new OneHotEncoderEstimator(uid) + .setInputCols(inputCols) + .setOutputCols(outputCols) + .setDropLast(true) + } + encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) + .setHandleInvalid($(handleInvalid)) encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap) encoderStages += new ColumnPruner(tempColumns.toSet) @@ -423,7 +446,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -487,7 +510,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { val columnsToPrune = data.getAs[Seq[String]](0).toSet val pruner = new ColumnPruner(metadata.uid, columnsToPrune) - DefaultParamsReader.getAndSetParams(pruner, metadata) + metadata.getAndSetParams(pruner) pruner } } @@ -579,7 +602,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite val prefixesToRewrite = data.getAs[Map[String, String]](1) val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite) - DefaultParamsReader.getAndSetParams(rewriter, metadata) + metadata.getAndSetParams(rewriter) rewriter } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 8f125d8fd51d..91b0707dec3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -212,7 +212,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { .select("std", "mean") .head() val model = new StandardScalerModel(metadata.uid, std, mean) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 1cdcdfcaeab7..a833d8b270cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -234,7 +234,7 @@ class StringIndexerModel ( val metadata = NominalAttribute.defaultAttr .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. 
- val (filteredDataset, keepInvalid) = getHandleInvalid match { + val (filteredDataset, keepInvalid) = $(handleInvalid) match { case StringIndexer.SKIP_INVALID => val filterer = udf { label: String => labelToIndex.contains(label) @@ -315,7 +315,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { .head() val labels = data.getAs[Seq[String]](0).toArray val model = new StringIndexerModel(metadata.uid, labels) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index b373ae921ed3..4061154b39c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -17,14 +17,17 @@ package org.apache.spark.ml.feature -import scala.collection.mutable.ArrayBuilder +import java.util.NoSuchElementException + +import scala.collection.mutable +import scala.language.existentials import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -33,10 +36,14 @@ import org.apache.spark.sql.types._ /** * A feature transformer that merges multiple columns into a vector column. + * + * This requires one pass over the entire dataset. In case we need to infer column lengths from the + * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter. */ @Since("1.4.0") class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { + extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid + with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("vecAssembler")) @@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.4.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** + * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + * output). Column lengths are taken from the size of ML Attribute Group, which can be set using + * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + * Default: "error" + * @group param + */ + @Since("2.4.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out + |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN + |in the output). 
Column lengths are taken from the size of ML Attribute Group, which can be + |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also + |be inferred from first rows of the data since it is safe to do so but only in case of 'error' + |or 'skip'.""".stripMargin.replaceAll("\n", " "), + ParamValidators.inArray(VectorAssembler.supportedHandleInvalids)) + + setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Schema transformation. val schema = dataset.schema - lazy val first = dataset.toDF.first() - val attrs = $(inputCols).flatMap { c => + + val vectorCols = $(inputCols).filter { c => + schema(c).dataType match { + case _: VectorUDT => true + case _ => false + } + } + val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid)) + + val featureAttributesMap = $(inputCols).map { c => val field = schema(c) - val index = schema.fieldIndex(c) field.dataType match { case DoubleType => - val attr = Attribute.fromStructField(field) - // If the input column doesn't have ML attribute, assume numeric. - if (attr == UnresolvedAttribute) { - Some(NumericAttribute.defaultAttr.withName(c)) - } else { - Some(attr.withName(c)) + val attribute = Attribute.fromStructField(field) + attribute match { + case UnresolvedAttribute => + Seq(NumericAttribute.defaultAttr.withName(c)) + case _ => + Seq(attribute.withName(c)) } case _: NumericType | BooleanType => // If the input column type is a compatible scalar type, assume numeric. - Some(NumericAttribute.defaultAttr.withName(c)) + Seq(NumericAttribute.defaultAttr.withName(c)) case _: VectorUDT => - val group = AttributeGroup.fromStructField(field) - if (group.attributes.isDefined) { - // If attributes are defined, copy them with updated names. - group.attributes.get.zipWithIndex.map { case (attr, i) => + val attributeGroup = AttributeGroup.fromStructField(field) + if (attributeGroup.attributes.isDefined) { + attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) => if (attr.name.isDefined) { // TODO: Define a rigorous naming scheme. attr.withName(c + "_" + attr.name.get) @@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes // from metadata, check the first row. - val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) - Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i)) + (0 until vectorColsLengths(c)).map { i => + NumericAttribute.defaultAttr.withName(c + "_" + i) + } } case otherType => throw new SparkException(s"VectorAssembler does not support the $otherType type") } } - val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() - + val featureAttributes = featureAttributesMap.flatten[Attribute].toArray + val lengths = featureAttributesMap.map(a => a.length).toArray + val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata() + val (filteredDataset, keepInvalid) = $(handleInvalid) match { + case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false) + case VectorAssembler.KEEP_INVALID => (dataset, true) + case VectorAssembler.ERROR_INVALID => (dataset, false) + } // Data transformation. 
val assembleFunc = udf { r: Row => - VectorAssembler.assemble(r.toSeq: _*) + VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*) }.asNondeterministic() val args = $(inputCols).map { c => schema(c).dataType match { @@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } } - dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) + filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) } @Since("1.4.0") @@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.6.0") object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + + /** + * Infers lengths of vector columns from the first row of the dataset + * @param dataset the dataset + * @param columns name of vector columns whose lengths need to be inferred + * @return map of column names to lengths + */ + private[feature] def getVectorLengthsFromFirstRow( + dataset: Dataset[_], + columns: Seq[String]): Map[String, Int] = { + try { + val first_row = dataset.toDF().select(columns.map(col): _*).first() + columns.zip(first_row.toSeq).map { + case (c, x) => c -> x.asInstanceOf[Vector].size + }.toMap + } catch { + case e: NullPointerException => throw new NullPointerException( + s"""Encountered null value while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + case e: NoSuchElementException => throw new NoSuchElementException( + s"""Encountered empty dataframe while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + } + } + + private[feature] def getLengths( + dataset: Dataset[_], + columns: Seq[String], + handleInvalid: String): Map[String, Int] = { + val groupSizes = columns.map { c => + c -> AttributeGroup.fromStructField(dataset.schema(c)).size + }.toMap + val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq + val firstSizes = (missingColumns.nonEmpty, handleInvalid) match { + case (true, VectorAssembler.ERROR_INVALID) => + getVectorLengthsFromFirstRow(dataset, missingColumns) + case (true, VectorAssembler.SKIP_INVALID) => + getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns) + case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException( + s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint + |to add metadata for columns: ${columns.mkString("[", ", ", "]")}.""" + .stripMargin.replaceAll("\n", " ")) + case (_, _) => Map.empty + } + groupSizes ++ firstSizes + } + + @Since("1.6.0") override def load(path: String): VectorAssembler = super.load(path) - private[feature] def assemble(vv: Any*): Vector = { - val indices = ArrayBuilder.make[Int] - val values = ArrayBuilder.make[Double] - var cur = 0 + /** + * Returns a function that has the required information to assemble each row. 
+ * @param lengths an array of lengths of input columns, whose size should be equal to the number + * of cells in the row (vv) + * @param keepInvalid indicate whether to throw an error or not on seeing a null in the rows + * @return a udf that can be applied on each row + */ + private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = { + val indices = mutable.ArrayBuilder.make[Int] + val values = mutable.ArrayBuilder.make[Double] + var featureIndex = 0 + + var inputColumnIndex = 0 vv.foreach { case v: Double => - if (v != 0.0) { - indices += cur + if (v.isNaN && !keepInvalid) { + throw new SparkException( + s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider + |removing NaNs from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } else if (v != 0.0) { + indices += featureIndex values += v } - cur += 1 + inputColumnIndex += 1 + featureIndex += 1 case vec: Vector => vec.foreachActive { case (i, v) => if (v != 0.0) { - indices += cur + i + indices += featureIndex + i values += v } } - cur += vec.size + inputColumnIndex += 1 + featureIndex += vec.size case null => - // TODO: output Double.NaN? - throw new SparkException("Values to assemble cannot be null.") + if (keepInvalid) { + val length: Int = lengths(inputColumnIndex) + Array.range(0, length).foreach { i => + indices += featureIndex + i + values += Double.NaN + } + inputColumnIndex += 1 + featureIndex += length + } else { + throw new SparkException( + s"""Encountered null while assembling a row with handleInvalid = "error". Consider + |removing nulls from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } case o => throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") } - Vectors.sparse(cur, indices.result(), values.result()).compressed + Vectors.sparse(featureIndex, indices.result(), values.result()).compressed } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index e6ec4e2e36ff..0e7396a621db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -537,7 +537,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { val numFeatures = data.getAs[Int](0) val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index fe3306e1e50d..fc9996d69ba7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -410,7 +410,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { } val model = new Word2VecModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index aa7871d6ff29..0bf405d9abf9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -32,6 +32,7 @@ import
org.apache.spark.mllib.fpm.FPGrowth.FreqItemset import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel /** * Common params for FPGrowth and FPGrowthModel @@ -158,18 +159,30 @@ class FPGrowth @Since("2.2.0") ( } private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { + val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val data = dataset.select($(itemsCol)) - val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) + val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) if (isSet(numPartitions)) { mllibFP.setNumPartitions($(numPartitions)) } + + if (handlePersistence) { + items.persist(StorageLevel.MEMORY_AND_DISK) + } + val parentModel = mllibFP.run(items) val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) val schema = StructType(Seq( StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val frequentItems = dataset.sparkSession.createDataFrame(rows, schema) + + if (handlePersistence) { + items.unpersist() + } + copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) } @@ -322,7 +335,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { val dataPath = new Path(path, "data").toString val frequentItems = sparkSession.read.parquet(dataPath) val model = new FPGrowthModel(metadata.uid, frequentItems) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala new file mode 100644 index 000000000000..02168fee16db --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.fpm + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} + +/** + * :: Experimental :: + * A parallel PrefixSpan algorithm to mine frequent sequential patterns. + * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + * Efficiently by Prefix-Projected Pattern Growth + * (see here). 
+ * + * @see Sequential Pattern Mining + * (Wikipedia) + */ +@Since("2.4.0") +@Experimental +object PrefixSpan { + + /** + * :: Experimental :: + * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + * + * @param dataset A dataset or a dataframe containing a sequence column which is + * {{{Seq[Seq[_]]}}} type + * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column + * are ignored + * @param minSupport the minimal support level of the sequential pattern, any pattern that + * appears more than (minSupport * size-of-the-dataset) times will be output + * (recommended value: `0.1`). + * @param maxPatternLength the maximal length of the sequential pattern + * (recommended value: `10`). + * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the + * internal storage format) allowed in a projected database before + * local processing. If a projected database exceeds this size, another + * iteration of distributed prefix growth is run + * (recommended value: `32000000`). + * @return A `DataFrame` that contains columns of sequence and corresponding frequency. + * The schema of it will be: + * - `sequence: Seq[Seq[T]]` (T is the item type) + * - `freq: Long` + */ + @Since("2.4.0") + def findFrequentSequentialPatterns( + dataset: Dataset[_], + sequenceCol: String, + minSupport: Double, + maxPatternLength: Int, + maxLocalProjDBSize: Long): DataFrame = { + + val inputType = dataset.schema(sequenceCol).dataType + require(inputType.isInstanceOf[ArrayType] && + inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], + s"The input column must be ArrayType and the array element type must also be ArrayType, " + + s"but got $inputType.") + + + val data = dataset.select(sequenceCol) + val sequences = data.where(col(sequenceCol).isNotNull).rdd + .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) + + val mllibPrefixSpan = new mllibPrefixSpan() + .setMinSupport(minSupport) + .setMaxPatternLength(maxPatternLength) + .setMaxLocalProjDBSize(maxLocalProjDBSize) + + val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) + val schema = StructType(Seq( + StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false), + StructField("freq", LongType, nullable = false))) + val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) + + freqSequences + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala index f7850b238465..dcc40b6668c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -169,12 +169,11 @@ object ImageSchema { var offset = 0 for (h <- 0 until height) { for (w <- 0 until width) { - val color = new Color(img.getRGB(w, h)) - + val color = new Color(img.getRGB(w, h), hasAlpha) decoded(offset) = color.getBlue.toByte decoded(offset + 1) = color.getGreen.toByte decoded(offset + 2) = color.getRed.toByte - if (nChannels == 4) { + if (hasAlpha) { decoded(offset + 3) = color.getAlpha.toByte } offset += nChannels diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 6961b45f55e4..572b8cf0051b 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -61,9 +61,12 @@ private[ml] class IterativelyReweightedLeastSquares( val fitIntercept: Boolean, val regParam: Double, val maxIter: Int, - val tol: Double) extends Logging with Serializable { + val tol: Double) extends Serializable { - def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = { + def fit( + instances: RDD[OffsetInstance], + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[IterativelyReweightedLeastSquares])): IterativelyReweightedLeastSquaresModel = { var converged = false var iter = 0 @@ -83,7 +86,8 @@ private[ml] class IterativelyReweightedLeastSquares( // Estimate new model model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, - standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + standardizeFeatures = false, standardizeLabel = false) + .fit(newInstances, instr = instr) // Check convergence val oldCoefficients = oldModel.coefficients @@ -96,14 +100,14 @@ private[ml] class IterativelyReweightedLeastSquares( if (maxTol < tol) { converged = true - logInfo(s"IRLS converged in $iter iterations.") + instr.logInfo(s"IRLS converged in $iter iterations.") } - logInfo(s"Iteration $iter : relative tolerance = $maxTol") + instr.logInfo(s"Iteration $iter : relative tolerance = $maxTol") iter = iter + 1 if (iter == maxIter) { - logInfo(s"IRLS reached the max number of iterations: $maxIter.") + instr.logInfo(s"IRLS reached the max number of iterations: $maxIter.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index c5c9c8eb2bd2..1b7c15f1f0a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -81,13 +81,11 @@ private[ml] class WeightedLeastSquares( val standardizeLabel: Boolean, val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto, val maxIter: Int = 100, - val tol: Double = 1e-6) extends Logging with Serializable { + val tol: Double = 1e-6 + ) extends Serializable { import WeightedLeastSquares._ require(regParam >= 0.0, s"regParam cannot be negative: $regParam") - if (regParam == 0.0) { - logWarning("regParam is zero, which might cause numerical instability and overfitting.") - } require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0, s"elasticNetParam must be in [0, 1]: $elasticNetParam") require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter") @@ -96,10 +94,17 @@ private[ml] class WeightedLeastSquares( /** * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. 
*/ - def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = { + def fit( + instances: RDD[Instance], + instr: OptionalInstrumentation = OptionalInstrumentation.create(classOf[WeightedLeastSquares]) + ): WeightedLeastSquaresModel = { + if (regParam == 0.0) { + instr.logWarning("regParam is zero, which might cause numerical instability and overfitting.") + } + val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) summary.validate() - logInfo(s"Number of instances: ${summary.count}.") + instr.logInfo(s"Number of instances: ${summary.count}.") val k = if (fitIntercept) summary.k + 1 else summary.k val numFeatures = summary.k val triK = summary.triK @@ -114,11 +119,12 @@ private[ml] class WeightedLeastSquares( if (rawBStd == 0) { if (fitIntercept || rawBBar == 0.0) { if (rawBBar == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } val coefficients = new DenseVector(Array.ofDim(numFeatures)) @@ -128,7 +134,7 @@ private[ml] class WeightedLeastSquares( } else { require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " + "zero. Model cannot be regularized with standardization=true") - logWarning(s"The standard deviation of the label is zero. Consider setting " + + instr.logWarning(s"The standard deviation of the label is zero. Consider setting " + s"fitIntercept=true.") } } @@ -256,7 +262,7 @@ private[ml] class WeightedLeastSquares( // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to // Quasi-Newton solver. case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto => - logWarning("Cholesky solver failed due to singular covariance matrix. " + + instr.logWarning("Cholesky solver failed due to singular covariance matrix. " + "Retrying with Quasi-Newton solver.") // ab and aa were modified in place, so reconstruct them val _aa = getAtA(aaBarValues, aBarValues) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 1b4b401ac4aa..e6c347ed17c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -249,6 +249,75 @@ object ParamValidators { def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => value.length > lowerBound } + + /** + * Utility for Param validity checks for Transformers which have both single- and multi-column + * support. This utility assumes that `inputCol` indicates single-column usage and + * that `inputCols` indicates multi-column usage. + * + * This checks to ensure that exactly one set of Params has been set, and it + * raises an `IllegalArgumentException` if not. + * + * @param singleColumnParams Params which should be set (or have defaults) if `inputCol` has been + * set. 
This does not need to include `inputCol`. + * @param multiColumnParams Params which should be set (or have defaults) if `inputCols` has been + * set. This does not need to include `inputCols`. + */ + def checkSingleVsMultiColumnParams( + model: Params, + singleColumnParams: Seq[Param[_]], + multiColumnParams: Seq[Param[_]]): Unit = { + val name = s"${model.getClass.getSimpleName} $model" + + def checkExclusiveParams( + isSingleCol: Boolean, + requiredParams: Seq[Param[_]], + excludedParams: Seq[Param[_]]): Unit = { + val badParamsMsgBuilder = new mutable.StringBuilder() + + val mustUnsetParams = excludedParams.filter(p => model.isSet(p)) + .map(_.name).mkString(", ") + if (mustUnsetParams.nonEmpty) { + badParamsMsgBuilder ++= + s"The following Params are not applicable and should not be set: $mustUnsetParams." + } + + val mustSetParams = requiredParams.filter(p => !model.isDefined(p)) + .map(_.name).mkString(", ") + if (mustSetParams.nonEmpty) { + badParamsMsgBuilder ++= + s"The following Params must be defined but are not set: $mustSetParams." + } + + val badParamsMsg = badParamsMsgBuilder.toString() + + if (badParamsMsg.nonEmpty) { + val errPrefix = if (isSingleCol) { + s"$name has the inputCol Param set for single-column transform." + } else { + s"$name has the inputCols Param set for multi-column transform." + } + throw new IllegalArgumentException(s"$errPrefix $badParamsMsg") + } + } + + val inputCol = model.getParam("inputCol") + val inputCols = model.getParam("inputCols") + + if (model.isSet(inputCol)) { + require(!model.isSet(inputCols), s"$name requires " + + s"exactly one of inputCol, inputCols Params to be set, but both are set.") + + checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams, + excludedParams = multiColumnParams) + } else if (model.isSet(inputCols)) { + checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams, + excludedParams = singleColumnParams) + } else { + throw new IllegalArgumentException(s"$name requires " + + s"exactly one of inputCol, inputCols Params to be set, but neither is set.") + } + } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... @@ -796,10 +865,10 @@ trait Params extends Identifiable with Serializable { } /** Internal param map for user-supplied values. */ - private val paramMap: ParamMap = ParamMap.empty + private[ml] val paramMap: ParamMap = ParamMap.empty /** Internal param map for default values. */ - private val defaultParamMap: ParamMap = ParamMap.empty + private[ml] val defaultParamMap: ParamMap = ParamMap.empty /** Validates that the input param belongs to this instance. */ private def shouldOwn(param: Param[_]): Unit = { @@ -836,6 +905,15 @@ trait Params extends Identifiable with Serializable { } } +private[ml] object Params { + /** + * Sets a default param value for a `Params`. + */ + private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = { + params.defaultParamMap.put(param -> value) + } +} + /** * :: DeveloperApi :: * Java-friendly wrapper for [[Params]]. 
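Usage note (editorial, not part of the patch): checkSingleVsMultiColumnParams is package-private to org.apache.spark.ml, so it is not callable from user code. The exclusivity rule it enforces can be sketched in standalone Scala as follows; checkSingleVsMulti and its arguments are hypothetical names used only for illustration, and the sketch omits the "required params must be defined" half of the real check:

  // Simplified sketch: exactly one of inputCol / inputCols may be set, and params
  // belonging to the other (unused) mode must remain unset.
  def checkSingleVsMulti(
      isSet: String => Boolean,
      singleModeParams: Seq[String],
      multiModeParams: Seq[String]): Unit = {
    (isSet("inputCol"), isSet("inputCols")) match {
      case (true, true) =>
        throw new IllegalArgumentException("both inputCol and inputCols are set")
      case (false, false) =>
        throw new IllegalArgumentException("neither inputCol nor inputCols is set")
      case (true, false) =>
        val bad = multiModeParams.filter(isSet)
        require(bad.isEmpty, s"not applicable for single-column usage: ${bad.mkString(", ")}")
      case (false, true) =>
        val bad = singleModeParams.filter(isSet)
        require(bad.isEmpty, s"not applicable for multi-column usage: ${bad.mkString(", ")}")
    }
  }

  // Valid: single-column usage with only single-column params set (a Set is usable
  // directly as the String => Boolean predicate).
  checkSingleVsMulti(Set("inputCol", "outputCol"), Seq("outputCol"), Seq("outputCols"))
  // Invalid (would throw): single-column usage but a multi-column param is also set.
  // checkSingleVsMulti(Set("inputCol", "outputCols"), Seq("outputCol"), Seq("outputCols"))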
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a5d57a15317e..b9c3170cc3c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -63,7 +63,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Array[String]]("outputCols", "output column names"), ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + - "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), + "every 10 iterations. Note: this setting will be ignored if the checkpoint directory " + + "is not set in the SparkContext", + isValid = "(interval: Int) => interval == -1 || interval >= 1"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + "will filter out rows with bad values), or error (which will throw an error). More " + @@ -89,7 +91,11 @@ private[shared] object SharedParamsCodeGen { "after fitting. If set to true, then all sub-models will be available. Warning: For " + "large models, collecting all sub-models can cause OOMs on the Spark driver", Some("false"), isExpertParam = true), - ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false) + ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false), + ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" + + " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"), + isValid = "(value: String) => " + + "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)") ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 13425dacc9f1..282ea6ebcbf7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -282,10 +282,10 @@ trait HasOutputCols extends Params { trait HasCheckpointInterval extends Params { /** - * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations", (interval: Int) => interval == -1 || interval >= 1) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. 
Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext", (interval: Int) => interval == -1 || interval >= 1) /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) @@ -504,4 +504,23 @@ trait HasLoss extends Params { /** @group getParam */ final def getLoss: String = $(loss) } + +/** + * Trait for shared param distanceMeasure (default: org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN). This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasDistanceMeasure extends Params { + + /** + * Param for The distance measure. Supported options: 'euclidean' and 'cosine'. + * @group param + */ + final val distanceMeasure: Param[String] = new Param[String](this, "distanceMeasure", "The distance measure. Supported options: 'euclidean' and 'cosine'", (value: String) => org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)) + + setDefault(distanceMeasure, org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN) + + /** @group getParam */ + final def getDistanceMeasure: String = $(distanceMeasure) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 81a8f50761e0..a23f9552b9e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -529,7 +529,7 @@ object ALSModel extends MLReadable[ALSModel] { val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 4b46c3831d75..e27a96e1f5df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -237,7 +237,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { - logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. 
This behavior is different from R survival::survreg.") } @@ -423,7 +423,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] .head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 0291a57487c4..8bcf0793a64c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor @Since("1.4.0") class DecisionTreeRegressionModel private[ml] ( override val uid: String, - override val rootNode: Node, + override val rootNode: RegressionNode, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable { @@ -175,10 +175,10 @@ class DecisionTreeRegressionModel private[ml] ( * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numFeatures: Int) = + private[ml] def this(rootNode: RegressionNode, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } @@ -279,9 +279,10 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession) - val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) + val model = new DecisionTreeRegressionModel(metadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) + metadata.getAndSetParams(model) model } } @@ -295,8 +296,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode require(oldModel.algo == OldAlgo.Regression, s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). 
Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") - new DecisionTreeRegressionModel(uid, rootNode, numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index f41d15b62ddd..8598e808c494 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -230,7 +230,7 @@ class GBTRegressionModel private[ml]( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) @@ -302,16 +302,16 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { override def load(path: String): GBTRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) + treeMetadata.getAndSetParams(tree) tree } @@ -319,7 +319,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 917a4d238d46..143c8a3548b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -404,7 +404,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - val wlsModel = optimizer.fit(instances) + val wlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) @@ -418,10 +418,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val 
OffsetInstance(label, weight, offset, features) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). - val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam), + instr = OptionalInstrumentation.create(instr)) val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, $(fitIntercept), $(regParam), $(maxIter), $(tol)) - val irlsModel = optimizer.fit(instances) + val irlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) @@ -471,6 +472,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] val epsilon: Double = 1E-16 + private[regression] def ylogy(y: Double, mu: Double): Double = { + if (y == 0) 0.0 else y * math.log(y / mu) + } + /** * Wrapper of family and link combination used in the model. */ @@ -488,7 +493,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def initialize( instances: RDD[OffsetInstance], fitIntercept: Boolean, - regParam: Double): WeightedLeastSquaresModel = { + regParam: Double, + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[GeneralizedLinearRegression]) + ): WeightedLeastSquaresModel = { val newInstances = instances.map { instance => val mu = family.initialize(instance.label, instance.weight) val eta = predict(mu) - instance.offset @@ -497,7 +505,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine // TODO: Make standardizeFeatures and standardizeLabel configurable. 
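Editorial aside on the `ylogy` helper being hoisted into the GeneralizedLinearRegression companion object above: the Binomial family already guarded the y == 0 case, and sharing the helper lets the Poisson deviance (changed further below) reuse the same zero-safe term. A minimal illustration of why the guard matters, using only plain IEEE-754 arithmetic (not taken from the patch):

{{{
// Without the guard a zero label makes the deviance NaN:
//   0.0 * math.log(0.0 / mu)  ==  0.0 * Double.NegativeInfinity  ==  Double.NaN
// With the guard the term is defined by continuity:
ylogy(0.0, 2.0)   // 0.0
ylogy(3.0, 2.0)   // 3.0 * math.log(1.5), about 1.216
// so 2.0 * weight * (ylogy(y, mu) - (y - mu)) stays finite when y == 0.
}}}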
val initialModel = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - .fit(newInstances) + .fit(newInstances, instr) initialModel } @@ -725,10 +733,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu * (1.0 - mu) - private def ylogy(y: Double, mu: Double): Double = { - if (y == 0) 0.0 else y * math.log(y / mu) - } - override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (ylogy(y, mu) + ylogy(1.0 - y, 1.0 - mu)) } @@ -783,7 +787,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu override def deviance(y: Double, mu: Double, weight: Double): Double = { - 2.0 * weight * (y * math.log(y / mu) - (y - mu)) + 2.0 * weight * (ylogy(y, mu) - (y - mu)) } override def aic( @@ -1010,7 +1014,7 @@ class GeneralizedLinearRegressionModel private[ml] ( private lazy val familyAndLink = FamilyAndLink(this) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { predict(features, 0.0) } @@ -1146,7 +1150,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 8faab52ea474..b046897ab2b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -308,7 +308,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val model = new IsotonicRegressionModel( metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index a5873d03b416..c45ade94a4e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging -import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.{PipelineStage, PredictorParams} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ @@ -39,10 +39,11 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, 
SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -338,7 +339,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = $(elasticNetParam), $(standardization), true, solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol)) - val model = optimizer.fit(instances) + val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) // When it is trained by WeightedLeastSquares, training summary does not // attach returned model. val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) @@ -377,6 +378,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val yMean = ySummarizer.mean(0) val rawYStd = math.sqrt(ySummarizer.variance(0)) + + instr.logNumExamples(ySummarizer.count) + instr.logNamedValue(Instrumentation.loggerTags.meanOfLabels, yMean) + instr.logNamedValue(Instrumentation.loggerTags.varianceOfLabels, rawYStd) + if (rawYStd == 0.0) { if ($(fitIntercept) || yMean == 0.0) { // If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with @@ -384,11 +390,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of // the fitIntercept. if (yMean == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } if (handlePersistence) instances.unpersist() @@ -414,7 +421,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") - logWarning(s"The standard deviation of the label is zero. " + + instr.logWarning(s"The standard deviation of the label is zero. " + "Consider setting fitIntercept=true.") } } @@ -429,7 +436,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LinearRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LinearRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -521,7 +528,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } if (state == null) { val msg = s"${optimizer.getClass.getName} failed." 
- logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -643,9 +650,9 @@ class LinearRegressionModel private[ml] ( @Since("1.3.0") val intercept: Double, @Since("2.3.0") val scale: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with MLWritable { + with LinearRegressionParams with GeneralMLWritable { - def this(uid: String, coefficients: Vector, intercept: Double) = + private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -699,7 +706,7 @@ class LinearRegressionModel private[ml] ( } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { dot(features, coefficients) + intercept } @@ -710,7 +717,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. @@ -718,7 +725,50 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) +} + +/** A writer for LinearRegression that handles the "internal" (or default) format */ +private class InternalLinearRegressionModelWriter + extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector, scale: Double) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[LinearRegressionModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients, scale + val data = Data(instance.intercept, instance.coefficients, instance.scale) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for LinearRegression that handles the "pmml" format */ +private class PMMLLinearRegressionModelWriter + extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val sc = sparkSession.sparkContext + // Construct the MLLib model which knows how to write to PMML. 
+ val instance = stage.asInstanceOf[LinearRegressionModel] + val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept) + // Save PMML + oldModel.toPMML(sc, path) + } } @Since("1.6.0") @@ -730,22 +780,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") override def load(path: String): LinearRegressionModel = super.load(path) - /** [[MLWriter]] instance for [[LinearRegressionModel]] */ - private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends MLWriter with Logging { - - private case class Data(intercept: Double, coefficients: Vector, scale: Double) - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: intercept, coefficients, scale - val data = Data(instance.intercept, instance.coefficients, instance.scale) - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) - } - } - private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ @@ -771,7 +805,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { new LinearRegressionModel(metadata.uid, coefficients, intercept, scale) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 200b234b7997..4509f85aafd1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -199,7 +199,7 @@ class RandomForestRegressionModel private[ml] ( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. 
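Editorial note: a change repeated across the regression models above (decision tree, GBT, generalized linear, linear, and random forest) is that `predict(features: Vector)` is widened from `protected` to public. A minimal usage sketch, assuming an already fitted model named `dtrModel` (the model name and feature values are illustrative only):

{{{
import org.apache.spark.ml.linalg.Vectors

// Single-vector scoring no longer requires building a one-row DataFrame and calling transform():
val prediction: Double = dtrModel.predict(Vectors.dense(1.0, 0.5, -2.3))
}}}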
@@ -269,21 +269,21 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala new file mode 100644 index 000000000000..adf814572671 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import scala.annotation.varargs + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.function.Function +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col + +/** + * :: Experimental :: + * + * Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. 
+ * For more information on KS Test: + * @see + * Kolmogorov-Smirnov test (Wikipedia) + */ +@Experimental +@Since("2.4.0") +object KolmogorovSmirnovTest { + + /** Used to construct output schema of test */ + private case class KolmogorovSmirnovTestResult( + pValue: Double, + statistic: Double) + + private def getSampleRDD(dataset: DataFrame, sampleCol: String): RDD[Double] = { + SchemaUtils.checkNumericType(dataset.schema, sampleCol) + import dataset.sparkSession.implicits._ + dataset.select(col(sampleCol).cast("double")).as[Double].rdd + } + + /** + * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. + * + * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test + * @param sampleCol Name of sample column in dataset, of any numerical type + * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value + * @return DataFrame containing the test result for the input sampled data. + * This DataFrame will contain a single Row with the following fields: + * - `pValue: Double` + * - `statistic: Double` + */ + @Since("2.4.0") + def test(dataset: Dataset[_], sampleCol: String, cdf: Double => Double): DataFrame = { + val spark = dataset.sparkSession + + val rdd = getSampleRDD(dataset.toDF(), sampleCol) + val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, cdf) + spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( + testResult.pValue, testResult.statistic))) + } + + /** + * Java-friendly version of `test(dataset: DataFrame, sampleCol: String, cdf: Double => Double)` + */ + @Since("2.4.0") + def test( + dataset: Dataset[_], + sampleCol: String, + cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { + test(dataset, sampleCol, (x: Double) => cdf.call(x).toDouble) + } + + /** + * Convenience function to conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability + * distribution equality. Currently supports the normal distribution, taking as parameters + * the mean and standard deviation. + * + * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test + * @param sampleCol Name of sample column in dataset, of any numerical type + * @param distName a `String` name for a theoretical distribution, currently only support "norm". + * @param params `Double*` specifying the parameters to be used for the theoretical distribution. + * For "norm" distribution, the parameters includes mean and variance. + * @return DataFrame containing the test result for the input sampled data. 
+ * This DataFrame will contain a single Row with the following fields: + * - `pValue: Double` + * - `statistic: Double` + */ + @Since("2.4.0") + @varargs + def test( + dataset: Dataset[_], + sampleCol: String, distName: String, + params: Double*): DataFrame = { + val spark = dataset.sparkSession + + val rdd = getSampleRDD(dataset.toDF(), sampleCol) + val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, distName, params: _*) + spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( + testResult.pValue, testResult.statistic))) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 9bed74a9f2c0..d40827edb6d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -75,7 +75,7 @@ sealed abstract class SummaryBuilder { * val Row(meanVec) = meanDF.first() * }}} * - * Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD + * Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD * interface. */ @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 07e98a142b10..0242bc76698d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,15 +17,16 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, - InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, + Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. */ -sealed abstract class Node extends Serializable { +sealed trait Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree // code into the new API and deprecate the old API. SPARK-3727 @@ -85,35 +86,86 @@ private[ml] object Node { /** * Create a new Node from the old Node format, recursively creating child nodes as needed. */ - def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = { + def fromOld( + oldNode: OldNode, + categoricalFeatures: Map[Int, Int], + isClassification: Boolean): Node = { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. 
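Editorial usage sketch for the `KolmogorovSmirnovTest.test` overloads introduced above (not part of the patch; `spark` is assumed to be an existing SparkSession and the column name is illustrative). Passing 0.0 and 1.0 with "norm" tests the sample against a standard normal distribution:

{{{
import org.apache.spark.ml.stat.KolmogorovSmirnovTest
import org.apache.spark.sql.Row

val df = spark.createDataFrame(Seq(0.1, -0.4, 1.3, 0.7).map(Tuple1.apply)).toDF("sample")
val Row(pValue: Double, statistic: Double) =
  KolmogorovSmirnovTest.test(df, "sample", "norm", 0.0, 1.0).first()
}}}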
- new LeafNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, impurityStats = null) + if (isClassification) { + new ClassificationLeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) + } else { + new RegressionLeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) + } } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain } else { 0.0 } - new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, - gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), - rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), - split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + if (isClassification) { + new ClassificationInternalNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, gain = gain, + leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true) + .asInstanceOf[ClassificationNode], + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true) + .asInstanceOf[ClassificationNode], + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + } else { + new RegressionInternalNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, gain = gain, + leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false) + .asInstanceOf[RegressionNode], + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false) + .asInstanceOf[RegressionNode], + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + } } } } -/** - * Decision tree leaf node. - * @param prediction Prediction this node makes - * @param impurity Impurity measure at this node (for training data) - */ -class LeafNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override private[ml] val impurityStats: ImpurityCalculator) extends Node { +@Since("2.4.0") +sealed trait ClassificationNode extends Node { + + /** + * Get count of training examples for specified label in this node + * @param label label number in the range [0, numClasses) + */ + @Since("2.4.0") + def getLabelCount(label: Int): Double = { + require(label >= 0 && label < impurityStats.stats.length, + "label should be in the range between 0 (inclusive) " + + s"and ${impurityStats.stats.length} (exclusive).") + impurityStats.stats(label) + } +} + +@Since("2.4.0") +sealed trait RegressionNode extends Node { + + /** Number of training data points in this node */ + @Since("2.4.0") + def getCount: Double = impurityStats.stats(0) + + /** Sum over training data points of the labels in this node */ + @Since("2.4.0") + def getSum: Double = impurityStats.stats(1) + + /** Sum over training data points of the square of the labels in this node */ + @Since("2.4.0") + def getSumOfSquares: Double = impurityStats.stats(2) +} + +@Since("2.4.0") +sealed trait LeafNode extends Node { + + /** Prediction this node makes. */ + def prediction: Double + + def impurity: Double override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" @@ -136,32 +188,58 @@ class LeafNode private[ml] ( override private[ml] def maxSplitFeatureIndex(): Int = -1 +} + +/** + * Decision tree leaf node for classification. 
+ */ +@Since("2.4.0") +class ClassificationLeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override private[ml] val impurityStats: ImpurityCalculator) + extends ClassificationNode with LeafNode { + override private[tree] def deepCopy(): Node = { - new LeafNode(prediction, impurity, impurityStats) + new ClassificationLeafNode(prediction, impurity, impurityStats) } } /** - * Internal Decision Tree node. - * @param prediction Prediction this node would make if it were a leaf node - * @param impurity Impurity measure at this node (for training data) - * @param gain Information gain value. Values less than 0 indicate missing values; - * this quirk will be removed with future updates. - * @param leftChild Left-hand child node - * @param rightChild Right-hand child node - * @param split Information about the test used to split to the left or right child. + * Decision tree leaf node for regression. */ -class InternalNode private[ml] ( +@Since("2.4.0") +class RegressionLeafNode private[ml] ( override val prediction: Double, override val impurity: Double, - val gain: Double, - val leftChild: Node, - val rightChild: Node, - val split: Split, - override private[ml] val impurityStats: ImpurityCalculator) extends Node { + override private[ml] val impurityStats: ImpurityCalculator) + extends RegressionNode with LeafNode { - // Note to developers: The constructor argument impurityStats should be reconsidered before we - // make the constructor public. We may be able to improve the representation. + override private[tree] def deepCopy(): Node = { + new RegressionLeafNode(prediction, impurity, impurityStats) + } +} + +/** + * Internal Decision Tree node. + */ +@Since("2.4.0") +sealed trait InternalNode extends Node { + + /** + * Information gain value. Values less than 0 indicate missing values; + * this quirk will be removed with future updates. + */ + def gain: Double + + /** Left-hand child node */ + def leftChild: Node + + /** Right-hand child node */ + def rightChild: Node + + /** Information about the test used to split to the left or right child. */ + def split: Split override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" @@ -206,11 +284,6 @@ class InternalNode private[ml] ( math.max(split.featureIndex, math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) } - - override private[tree] def deepCopy(): Node = { - new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(), - split, impurityStats) - } } private object InternalNode { @@ -241,6 +314,57 @@ private object InternalNode { } } +/** + * Internal Decision Tree node for regression. + */ +@Since("2.4.0") +class ClassificationInternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override val gain: Double, + override val leftChild: ClassificationNode, + override val rightChild: ClassificationNode, + override val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) + extends ClassificationNode with InternalNode { + + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. 
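Editorial sketch of the node-statistics accessors added above (not part of the patch; `dtrModel` is assumed to be a fitted DecisionTreeRegressionModel, whose `rootNode` is now typed as RegressionNode):

{{{
val root = dtrModel.rootNode       // a RegressionNode after this change
val n = root.getCount              // training rows routed to the root
val meanLabel = root.getSum / n    // mean label of those rows
// Classification trees expose per-class counts analogously via
// ClassificationNode.getLabelCount(label).
}}}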
+ + override private[tree] def deepCopy(): Node = { + new ClassificationInternalNode(prediction, impurity, gain, + leftChild.deepCopy().asInstanceOf[ClassificationNode], + rightChild.deepCopy().asInstanceOf[ClassificationNode], + split, impurityStats) + } +} + +/** + * Internal Decision Tree node for regression. + */ +@Since("2.4.0") +class RegressionInternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override val gain: Double, + override val leftChild: RegressionNode, + override val rightChild: RegressionNode, + override val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) + extends RegressionNode with InternalNode { + + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. + + override private[tree] def deepCopy(): Node = { + new RegressionInternalNode(prediction, impurity, gain, + leftChild.deepCopy().asInstanceOf[RegressionNode], + rightChild.deepCopy().asInstanceOf[RegressionNode], + split, impurityStats) + } +} + + /** * Version of a node used in learning. This uses vars so that we can modify nodes as we split the * tree by adding children, etc. @@ -266,24 +390,53 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { + def toNode(isClassification: Boolean): Node = toNode(isClassification, prune = true) + + def toClassificationNode(prune: Boolean = true): ClassificationNode = { + toNode(true, prune).asInstanceOf[ClassificationNode] + } + + def toRegressionNode(prune: Boolean = true): RegressionNode = { + toNode(false, prune).asInstanceOf[RegressionNode] + } + /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode: Node = { - if (leftChild.nonEmpty) { - assert(rightChild.nonEmpty && split.nonEmpty && stats != null, + def toNode(isClassification: Boolean, prune: Boolean): Node = { + + if (!leftChild.isEmpty || !rightChild.isEmpty) { + assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. 
Could not convert LearningNode to Node.") - new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) + (leftChild.get.toNode(isClassification, prune), + rightChild.get.toNode(isClassification, prune)) match { + case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => + if (isClassification) { + new ClassificationLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + } else { + new RegressionLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + } + case (l, r) => + if (isClassification) { + new ClassificationInternalNode(stats.impurityCalculator.predict, stats.impurity, + stats.gain, l.asInstanceOf[ClassificationNode], r.asInstanceOf[ClassificationNode], + split.get, stats.impurityCalculator) + } else { + new RegressionInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode], + split.get, stats.impurityCalculator) + } + } } else { - if (stats.valid) { - new LeafNode(stats.impurityCalculator.predict, stats.impurity, + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + val impurity = if (stats.valid) stats.impurity else -1.0 + if (isClassification) { + new ClassificationLeafNode(stats.impurityCalculator.predict, impurity, stats.impurityCalculator) } else { - // Here we want to keep same behavior with the old mllib.DecisionTreeModel - new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) + new RegressionLeafNode(stats.impurityCalculator.predict, impurity, + stats.impurityCalculator) } - } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index a7c5f489dea8..5b14a63ada4e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -95,7 +95,7 @@ private[spark] class NodeIdCache( splits: Array[Array[Split]]): Unit = { if (prevNodeIdsForInstances != null) { // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + prevNodeIdsForInstances.unpersist(false) } prevNodeIdsForInstances = nodeIdsForInstances @@ -166,9 +166,13 @@ private[spark] class NodeIdCache( } } } + if (nodeIdsForInstances != null) { + // Unpersist current one if one exists. + nodeIdsForInstances.unpersist(false) + } if (prevNodeIdsForInstances != null) { // Unpersist the previous one if one exists. 
- prevNodeIdsForInstances.unpersist() + prevNodeIdsForInstances.unpersist(false) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index acfc6399c553..905870178e54 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation[_]], + prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { val timer = new TimeTracker() @@ -107,9 +108,11 @@ private[spark] object RandomForest extends Logging { case Some(instrumentation) => instrumentation.logNumFeatures(metadata.numFeatures) instrumentation.logNumClasses(metadata.numClasses) + instrumentation.logNumExamples(metadata.numExamples) case None => logInfo("numFeatures: " + metadata.numFeatures) logInfo("numClasses: " + metadata.numClasses) + logInfo("numExamples: " + metadata.numExamples) } // Find the splits and the corresponding bins (interval between the splits) using a sample @@ -223,22 +226,23 @@ private[spark] object RandomForest extends Logging { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, - strategy.getNumClasses) + new DecisionTreeClassificationModel(uid, rootNode.toClassificationNode(prune), + numFeatures, strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toRegressionNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), numFeatures, strategy.getNumClasses) } } else { - topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures)) + topNodes.map(rootNode => + new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), numFeatures)) } } } @@ -890,13 +894,7 @@ private[spark] object RandomForest extends Logging { // Sample the input only if there are continuous features. val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous) val sampledInput = if (continuousFeatures.nonEmpty) { - // Calculate the number of samples for approximate quantile calculation. 
- val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) - val fraction = if (requiredSamples < metadata.numExamples) { - requiredSamples.toDouble / metadata.numExamples - } else { - 1.0 - } + val fraction = samplesFractionForFindSplits(metadata) logDebug("fraction of data used for calculating quantiles = " + fraction) input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) } else { @@ -918,8 +916,9 @@ private[spark] object RandomForest extends Logging { val numPartitions = math.min(continuousFeatures.length, input.partitions.length) input - .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx)))) - .groupByKey(numPartitions) + .flatMap { point => + continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0) + }.groupByKey(numPartitions) .map { case (idx, samples) => val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) @@ -931,7 +930,8 @@ private[spark] object RandomForest extends Logging { val numFeatures = metadata.numFeatures val splits: Array[Array[Split]] = Array.tabulate(numFeatures) { case i if metadata.isContinuous(i) => - val split = continuousSplits(i) + // some features may contain only zero, so continuousSplits will not have a record + val split = continuousSplits.getOrElse(i, Array.empty[Split]) metadata.setNumSplits(i, split.length) split @@ -1001,11 +1001,22 @@ private[spark] object RandomForest extends Logging { } else { val numSplits = metadata.numSplits(featureIndex) - // get count for each distinct value - val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) { - case ((m, cnt), x) => - (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1) + // get count for each distinct value except zero value + val partNumSamples = featureSamples.size + val partValueCountMap = scala.collection.mutable.Map[Double, Int]() + featureSamples.foreach { x => + partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1 + } + + // Calculate the expected number of samples for finding splits + val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt + // add expected zero value count and get complete statistics + val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) { + partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples)) + } else { + partValueCountMap.toMap } + // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray @@ -1147,4 +1158,21 @@ private[spark] object RandomForest extends Logging { 3 * totalBins } } + + /** + * Calculate the subsample fraction for finding splits + * + * @param metadata decision tree metadata + * @return subsample fraction + */ + private def samplesFractionForFindSplits( + metadata: DecisionTreeMetadata): Double = { + // Calculate the number of samples for approximate quantile calculation. 
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 4aa4c3617e7f..f027b14f1d47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -219,8 +219,10 @@ private[ml] object TreeEnsembleModel { importances.changeValue(feature, scaledGain, _ + scaledGain) computeFeatureImportance(n.leftChild, importances) computeFeatureImportance(n.rightChild, importances) - case n: LeafNode => + case _: LeafNode => // do nothing + case _ => + throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -317,6 +319,8 @@ private[ml] object DecisionTreeModelReadWrite { (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), id) + case _ => + throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -327,7 +331,7 @@ private[ml] object DecisionTreeModelReadWrite { def loadTreeNodes( path: String, metadata: DefaultParamsReader.Metadata, - sparkSession: SparkSession): Node = { + sparkSession: SparkSession, isClassification: Boolean): Node = { import sparkSession.implicits._ implicit val format = DefaultFormats @@ -339,7 +343,7 @@ private[ml] object DecisionTreeModelReadWrite { val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).as[NodeData] - buildTreeFromNodes(data.collect(), impurityType) + buildTreeFromNodes(data.collect(), impurityType, isClassification) } /** @@ -348,7 +352,8 @@ private[ml] object DecisionTreeModelReadWrite { * @param impurityType Impurity type for this tree * @return Root node of reconstructed tree */ - def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = { + def buildTreeFromNodes(data: Array[NodeData], impurityType: String, + isClassification: Boolean): Node = { // Load all nodes, sorted by ID. 
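Editorial worked example for the `samplesFractionForFindSplits` helper factored out above in RandomForest.scala (numbers are illustrative). The same fraction is also what `findSplitsForContinuousFeature` now uses to add back the expected count of the zero values skipped during sampling:

{{{
// maxBins = 32        =>  requiredSamples = math.max(32 * 32, 10000) = 10000
// numExamples = 1000000  =>  fraction = 10000.0 / 1000000 = 0.01
// numExamples = 5000     =>  requiredSamples >= numExamples, so fraction = 1.0
}}}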
val nodes = data.sortBy(_.id) // Sanity checks; could remove @@ -364,10 +369,21 @@ private[ml] object DecisionTreeModelReadWrite { val node = if (n.leftChild != -1) { val leftChild = finalNodes(n.leftChild) val rightChild = finalNodes(n.rightChild) - new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild, - n.split.getSplit, impurityStats) + if (isClassification) { + new ClassificationInternalNode(n.prediction, n.impurity, n.gain, + leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode], + n.split.getSplit, impurityStats) + } else { + new RegressionInternalNode(n.prediction, n.impurity, n.gain, + leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode], + n.split.getSplit, impurityStats) + } } else { - new LeafNode(n.prediction, n.impurity, impurityStats) + if (isClassification) { + new ClassificationLeafNode(n.prediction, n.impurity, impurityStats) + } else { + new RegressionLeafNode(n.prediction, n.impurity, impurityStats) + } } finalNodes(n.id) = node } @@ -421,7 +437,8 @@ private[ml] object EnsembleModelReadWrite { path: String, sql: SparkSession, className: String, - treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { + treeClassName: String, + isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) @@ -449,7 +466,8 @@ private[ml] object EnsembleModelReadWrite { val rootNodesRDD: RDD[(Int, Node)] = nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { case (treeID: Int, nodeData: Iterable[NodeData]) => - treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes( + nodeData.toArray, impurityType, isClassification) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() (metadata, treesMetadata.zip(rootNodes), treesWeights) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 095b54c0fe83..5e916cc4a9fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } (executionContext) } - // Wait for metrics to be calculated before unpersisting validation dataset + // Wait for metrics to be calculated val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + + // Unpersist training & validation set once all metrics have been produced trainingDataset.unpersist() validationDataset.unpersist() foldMetrics @@ -232,8 +234,7 @@ object CrossValidator extends MLReadable[CrossValidator] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(cv, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(cv, skipParams = Option(List("estimatorParamMaps"))) cv } } @@ -268,6 +269,17 @@ class CrossValidatorModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[JList[Model[_]]]) + : CrossValidatorModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray.map(_.asScala.toArray)) + } else { + None + } + this + } 
+ /** * @return submodels represented in two dimension array. The index of outer array is the * fold index, and the index of inner array corresponds to the ordering of @@ -411,8 +423,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c73bd1847547..13369c4df718 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -93,7 +93,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St def setSeed(value: Long): this.type = set(seed, value) /** - * Set the mamixum level of parallelism to evaluate models in parallel. + * Set the maximum level of parallelism to evaluate models in parallel. * Default is 1 for serial evaluation * * @group expertSetParam @@ -112,7 +112,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St * for more information. * * @group expertSetParam - */@Since("2.3.0") + */ + @Since("2.3.0") def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") @@ -143,24 +144,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { + val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => + Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] if (collectSubModelsParam) { subModels.get(paramIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") @@ -171,7 +161,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Wait for all metrics to be calculated val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) - // Unpersist validation set once all metrics have been produced + // Unpersist training & validation set once all metrics have been produced + trainingDataset.unpersist() validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}") @@ -237,8 +228,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(tvs, metadata, - skipParams = 
Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(tvs, skipParams = Option(List("estimatorParamMaps"))) tvs } } @@ -271,6 +261,17 @@ class TrainValidationSplitModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[Model[_]]) + : TrainValidationSplitModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray) + } else { + None + } + this + } + /** * @return submodels represented in array. The index of array corresponds to the ordering of * estimatorParamMaps @@ -405,8 +406,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala new file mode 100644 index 000000000000..6af4b3ebc2cc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Column, Dataset, Row} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} + + +private[spark] object DatasetUtils { + + /** + * Cast a column in a Dataset to Vector type. + * + * The supported data types of the input column are + * - Vector + * - float/double type Array. + * + * Note: The returned column does not have Metadata. + * + * @param dataset input DataFrame + * @param colName column name. 
+ * @return Vector column + */ + def columnToVector(dataset: Dataset[_], colName: String): Column = { + val columnDataType = dataset.schema(colName).dataType + columnDataType match { + case _: VectorUDT => col(colName) + case fdt: ArrayType => + val transferUDF = fdt.elementType match { + case _: FloatType => udf(f = (vector: Seq[Float]) => { + val inputArray = Array.fill[Double](vector.size)(0.0) + vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble) + Vectors.dense(inputArray) + }) + case _: DoubleType => udf((vector: Seq[Double]) => { + Vectors.dense(vector.toArray) + }) + case other => + throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector") + } + transferUDF(col(colName)) + case other => + throw new IllegalArgumentException(s"$other column cannot be cast to Vector") + } + } + + def columnToOldVector(dataset: Dataset[_], colName: String): RDD[OldVector] = { + dataset.select(columnToVector(dataset, colName)) + .rdd.map { + case Row(point: Vector) => OldVectors.fromML(point) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 7c46f45c5971..3247c394dfa6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -17,7 +17,9 @@ package org.apache.spark.ml.util -import java.util.concurrent.atomic.AtomicLong +import java.util.UUID + +import scala.reflect.ClassTag import org.json4s._ import org.json4s.JsonDSL._ @@ -40,9 +42,10 @@ import org.apache.spark.sql.Dataset * @tparam E the type of the estimator */ private[spark] class Instrumentation[E <: Estimator[_]] private ( - estimator: E, dataset: RDD[_]) extends Logging { + val estimator: E, + val dataset: RDD[_]) extends Logging { - private val id = Instrumentation.counter.incrementAndGet() + private val id = UUID.randomUUID() private val prefix = { val className = estimator.getClass.getSimpleName s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " @@ -56,12 +59,31 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( } /** - * Logs a message with a prefix that uniquely identifies the training session. + * Logs a warning message with a prefix that uniquely identifies the training session. + */ + override def logWarning(msg: => String): Unit = { + super.logWarning(prefix + msg) + } + + /** + * Logs a error message with a prefix that uniquely identifies the training session. */ - def log(msg: String): Unit = { - logInfo(prefix + msg) + override def logError(msg: => String): Unit = { + super.logError(prefix + msg) } + /** + * Logs an info message with a prefix that uniquely identifies the training session. + */ + override def logInfo(msg: => String): Unit = { + super.logInfo(prefix + msg) + } + + /** + * Alias for logInfo, see above. + */ + def log(msg: String): Unit = logInfo(msg) + /** * Logs the value of the given parameters for the estimator being used in this session. 
*/ @@ -77,11 +99,15 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( } def logNumFeatures(num: Long): Unit = { - log(compact(render("numFeatures" -> num))) + logNamedValue(Instrumentation.loggerTags.numFeatures, num) } def logNumClasses(num: Long): Unit = { - log(compact(render("numClasses" -> num))) + logNamedValue(Instrumentation.loggerTags.numClasses, num) + } + + def logNumExamples(num: Long): Unit = { + logNamedValue(Instrumentation.loggerTags.numExamples, num) } /** @@ -95,6 +121,10 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( log(compact(render(name -> value))) } + def logNamedValue(name: String, value: Double): Unit = { + log(compact(render(name -> value))) + } + /** * Logs the successful completion of the training session. */ @@ -107,7 +137,14 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( * Some common methods for logging information about a training session. */ private[spark] object Instrumentation { - private val counter = new AtomicLong(0) + + object loggerTags { + val numFeatures = "numFeatures" + val numClasses = "numClasses" + val numExamples = "numExamples" + val meanOfLabels = "meanOfLabels" + val varianceOfLabels = "varianceOfLabels" + } /** * Creates an instrumentation object for a training session. @@ -126,3 +163,56 @@ private[spark] object Instrumentation { } } + +/** + * A small wrapper that contains an optional `Instrumentation` object. + * Provide some log methods, if the containing `Instrumentation` object is defined, + * will log via it, otherwise will log via common logger. + */ +private[spark] class OptionalInstrumentation private( + val instrumentation: Option[Instrumentation[_ <: Estimator[_]]], + val className: String) extends Logging { + + protected override def logName: String = className + + override def logInfo(msg: => String) { + instrumentation match { + case Some(instr) => instr.logInfo(msg) + case None => super.logInfo(msg) + } + } + + override def logWarning(msg: => String) { + instrumentation match { + case Some(instr) => instr.logWarning(msg) + case None => super.logWarning(msg) + } + } + + override def logError(msg: => String) { + instrumentation match { + case Some(instr) => instr.logError(msg) + case None => super.logError(msg) + } + } +} + +private[spark] object OptionalInstrumentation { + + /** + * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. + */ + def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = { + new OptionalInstrumentation(Some(instr), + instr.estimator.getClass.getName.stripSuffix("$")) + } + + /** + * Creates an `OptionalInstrumentation` object from a `Class` object. + * The created `OptionalInstrumentation` object will log messages via common logger and use the + * specified class name as logger name. 
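OptionalInstrumentation is also private[spark]; its purpose is to let shared helper code log through the active training session when one exists and through an ordinary class-named logger otherwise. A sketch, with logLabelMean and SomeUtil as hypothetical names:

{{{
import org.apache.spark.rdd.RDD

// Helper that may be called with or without an active training session.
def logLabelMean(labels: RDD[Double], instr: OptionalInstrumentation): Double = {
  val mean = labels.mean()
  instr.logInfo(s"mean of labels: $mean")   // prefixed when instr wraps an Instrumentation
  mean
}

// From inside an estimator:   logLabelMean(labels, OptionalInstrumentation.create(instrumentation))
// From a standalone utility:  logLabelMean(labels, OptionalInstrumentation.create(classOf[SomeUtil]))
}}}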
+ */ + def create(clazz: Class[_]): OptionalInstrumentation = { + new OptionalInstrumentation(None, clazz.getName.stripSuffix("$")) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a61690780096..72a60e04360d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -18,9 +18,11 @@ package org.apache.spark.ml.util import java.io.IOException -import java.util.Locale +import java.util.{Locale, ServiceLoader} +import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.util.{Failure, Success, Try} import org.apache.hadoop.fs.Path import org.json4s._ @@ -28,8 +30,8 @@ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} @@ -37,7 +39,7 @@ import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.util.{Utils, VersionUtils} /** * Trait for `MLWriter` and `MLReader`. @@ -86,7 +88,82 @@ private[util] sealed trait BaseReadWrite { } /** - * Abstract class for utility classes that can save ML instances. + * Abstract class to be implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * @since 2.4.0 + */ +@InterfaceStability.Unstable +@Since("2.4.0") +trait MLWriterFormat { + /** + * Function to write the provided pipeline stage out. + * + * @param path The path to write the result out to. + * @param session SparkSession associated with the write request. + * @param optionMap User provided options stored as strings. + * @param stage The pipeline stage to be saved. + */ + @Since("2.4.0") + def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], + stage: PipelineStage): Unit +} + +/** + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * @since 2.4.0 + */ +@InterfaceStability.Unstable +@Since("2.4.0") +trait MLFormatRegister extends MLWriterFormat { + /** + * The string that represents the format that this format provider uses. This is, along with + * stageName, is overridden by children to provide a nice alias for the writer. For example: + * + * {{{ + * override def format(): String = + * "pmml" + * }}} + * Indicates that this format is capable of saving a pmml model. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. 
+ * @since 2.4.0 + */ + @Since("2.4.0") + def format(): String + + /** + * The string that represents the stage type that this writer supports. This is, along with + * format, is overridden by children to provide a nice alias for the writer. For example: + * + * {{{ + * override def stageName(): String = + * "org.apache.spark.ml.regression.LinearRegressionModel" + * }}} + * Indicates that this format is capable of saving Spark's own PMML model. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.4.0 + */ + @Since("2.4.0") + def stageName(): String + + private[ml] def shortName(): String = s"${format()}+${stageName()}" +} + +/** + * Abstract class for utility classes that can save ML instances in Spark's internal format. */ @Since("1.6.0") abstract class MLWriter extends BaseReadWrite with Logging { @@ -110,6 +187,15 @@ abstract class MLWriter extends BaseReadWrite with Logging { @Since("1.6.0") protected def saveImpl(path: String): Unit + /** + * Overwrites if the output path already exists. + */ + @Since("1.6.0") + def overwrite(): this.type = { + shouldOverwrite = true + this + } + /** * Map to store extra options for this writer. */ @@ -126,15 +212,73 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + @Since("1.6.0") + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + @Since("1.6.0") + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +@InterfaceStability.Unstable +@Since("2.4.0") +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { + private var source: String = "internal" + /** - * Overwrites if the output path already exists. + * Specifies the format of ML export (e.g. "pmml", "internal", or + * the fully qualified class name for export). */ - @Since("1.6.0") - def overwrite(): this.type = { - shouldOverwrite = true + @Since("2.4.0") + def format(source: String): this.type = { + this.source = source this } + /** + * Dispatches the save to the correct MLFormat. 
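For callers, the new writer keeps the fluent MLWriter style. A sketch assuming model is a fitted stage whose write returns a GeneralMLWriter (for instance a LinearRegressionModel in 2.4):

{{{
// model.write: GeneralMLWriter
model.write
  .format("pmml")        // short name, resolved through the ServiceLoader registry
  .overwrite()
  .save("/tmp/lr-pmml")

// A fully qualified MLWriterFormat class name also works, and the default
// remains Spark's internal layout:
model.write.format("internal").save("/tmp/lr-internal")
}}}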
+ */ + @Since("2.4.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + @throws[SparkException]("If multiple sources for a given short name format are found.") + override protected def saveImpl(path: String): Unit = { + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) + val stageName = stage.getClass.getName + val targetName = s"$source+$stageName" + val formats = serviceLoader.asScala.toList + val shortNames = formats.map(_.shortName()) + val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match { + // requested name did not match any given registered alias + case Nil => + Try(loader.loadClass(source)) match { + case Success(writer) => + // Found the ML writer using the fully qualified path + writer + case Failure(error) => + throw new SparkException( + s"Could not load requested format $source for $stageName ($targetName) had $formats" + + s"supporting $shortNames", error) + } + case head :: Nil => + head.getClass + case _ => + // Multiple sources + throw new SparkException( + s"Multiple writers found for $source+$stageName, try using the class name of the writer") + } + if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { + val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] + writer.write(path, sparkSession, optionMap, stage) + } else { + throw new SparkException(s"ML source $source is not a valid MLWriterFormat") + } + } + // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) @@ -162,6 +306,19 @@ trait MLWritable { def save(path: String): Unit = write.save(path) } +/** + * Trait for classes that provide `GeneralMLWriter`. + */ +@Since("2.4.0") +@InterfaceStability.Unstable +trait GeneralMLWritable extends MLWritable { + /** + * Returns an `MLWriter` instance for this ML instance. 
+ */ + @Since("2.4.0") + override def write: GeneralMLWriter +} + /** * :: DeveloperApi :: * @@ -264,6 +421,7 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid + * - defaultParamMap * - paramMap * - (optionally, extra metadata) * @@ -296,15 +454,20 @@ private[ml] object DefaultParamsWriter { paramMap: Option[JValue] = None): String = { val uid = instance.uid val cls = instance.getClass.getName - val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val params = instance.paramMap.toSeq + val defaultParams = instance.defaultParamMap.toSeq val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }.toList)) + val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList) val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("defaultParamMap" -> jsonDefaultParams) val metadata = extraMetadata match { case Some(jObject) => basicMetadata ~ jObject @@ -331,7 +494,7 @@ private[ml] class DefaultParamsReader[T] extends MLReader[T] { val cls = Utils.classForName(metadata.className) val instance = cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] - DefaultParamsReader.getAndSetParams(instance, metadata) + metadata.getAndSetParams(instance) instance.asInstanceOf[T] } } @@ -342,6 +505,8 @@ private[ml] object DefaultParamsReader { * All info from metadata file. * * @param params paramMap, as a `JValue` + * @param defaultParams defaultParamMap, as a `JValue`. For metadata file prior to Spark 2.4, + * this is `JNothing`. * @param metadata All metadata, including the other fields * @param metadataJson Full metadata file String (for debugging) */ @@ -351,27 +516,90 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + defaultParams: JValue, metadata: JValue, metadataJson: String) { + + private def getValueFromParams(params: JValue): Seq[(String, JValue)] = { + params match { + case JObject(pairs) => pairs + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: $metadataJson.") + } + } + /** * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. * This can be useful for getting a Param value before an instance of `Params` - * is available. + * is available. This will look up `params` first, if not existing then looking up + * `defaultParams`. */ def getParamValue(paramName: String): JValue = { implicit val format = DefaultFormats - params match { + + // Looking up for `params` first. + var pairs = getValueFromParams(params) + var foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + if (foundPairs.length == 0) { + // Looking up for `defaultParams` then. + pairs = getValueFromParams(defaultParams) + foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + } + assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" + + s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) + + foundPairs.map(_._2).head + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params (except params included by `skipParams` list) implement + * [[org.apache.spark.ml.param.Param.jsonDecode()]]. 
+ * + * @param skipParams The params included in `skipParams` won't be set. This is useful if some + * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] + * and need special handling. + */ + def getAndSetParams( + instance: Params, + skipParams: Option[List[String]] = None): Unit = { + setParams(instance, skipParams, isDefault = false) + + // For metadata file prior to Spark 2.4, there is no default section. + val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion) + if (major > 2 || (major == 2 && minor >= 4)) { + setParams(instance, skipParams, isDefault = true) + } + } + + private def setParams( + instance: Params, + skipParams: Option[List[String]], + isDefault: Boolean): Unit = { + implicit val format = DefaultFormats + val paramsToSet = if (isDefault) defaultParams else params + paramsToSet match { case JObject(pairs) => - val values = pairs.filter { case (pName, jsonValue) => - pName == paramName - }.map(_._2) - assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + - s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) - values.head + pairs.foreach { case (paramName, jsonValue) => + if (skipParams == None || !skipParams.get.contains(paramName)) { + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + if (isDefault) { + Params.setDefault(instance, param, value) + } else { + instance.set(param, value) + } + } + } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: $metadataJson.") + s"Cannot recognize JSON metadata: ${metadataJson}.") } } } @@ -404,43 +632,14 @@ private[ml] object DefaultParamsReader { val uid = (metadata \ "uid").extract[String] val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] + val defaultParams = metadata \ "defaultParamMap" val params = metadata \ "paramMap" if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) - } - - /** - * Extract Params from metadata, and set them in the instance. - * This works if all Params (except params included by `skipParams` list) implement - * [[org.apache.spark.ml.param.Param.jsonDecode()]]. - * - * @param skipParams The params included in `skipParams` won't be set. This is useful if some - * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] - * and need special handling. 
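The version gate above is the compatibility point: metadata written by Spark 2.3 and earlier has no defaultParamMap section, so default values are only restored from 2.4+ files. A small sketch of the check, mirroring the internal (private[spark]) VersionUtils helper used above:

{{{
import org.apache.spark.util.VersionUtils

def hasDefaultParamSection(sparkVersion: String): Boolean = {
  val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion)
  major > 2 || (major == 2 && minor >= 4)
}

hasDefaultParamSection("2.3.2")   // false: only paramMap is applied
hasDefaultParamSection("2.4.0")   // true: defaults are re-applied via Params.setDefault
}}}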
- * TODO: Move to [[Metadata]] method - */ - def getAndSetParams( - instance: Params, - metadata: Metadata, - skipParams: Option[List[String]] = None): Unit = { - implicit val format = DefaultFormats - metadata.params match { - case JObject(pairs) => - pairs.foreach { case (paramName, jsonValue) => - if (skipParams == None || !skipParams.get.contains(paramName)) { - val param = instance.getParam(paramName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) - } - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") - } + Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 334410c9620d..d9a3f85ef9a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,7 +17,8 @@ package org.apache.spark.ml.util -import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType} +import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.sql.types._ /** @@ -101,4 +102,17 @@ private[spark] object SchemaUtils { require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.") StructType(schema.fields :+ col) } + + /** + * Check whether the given column in the schema is one of the supporting vector type: Vector, + * Array[Float]. Array[Double] + * @param schema input schema + * @param colName column name + */ + def validateVectorCompatibleColumn(schema: StructType, colName: String): Unit = { + val typeCandidates = List( new VectorUDT, + new ArrayType(DoubleType, false), + new ArrayType(FloatType, false)) + checkColumnTypes(schema, colName, typeCandidates) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b32d3f252ae5..db3f074ecfba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -572,10 +572,7 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[java.lang.Iterable[Any]], minSupport: Double, numPartitions: Int): FPGrowthModel[Any] = { - val fpg = new FPGrowth() - .setMinSupport(minSupport) - .setNumPartitions(numPartitions) - + val fpg = new FPGrowth(minSupport, numPartitions) val model = fpg.run(data.rdd.map(_.asScala.toArray)) new FPGrowthModelWrapper(model) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 9b9c70cfe510..98af487306dc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -57,7 +57,8 @@ class BisectingKMeans private ( private var k: Int, 
private var maxIterations: Int, private var minDivisibleClusterSize: Double, - private var seed: Long) extends Logging { + private var seed: Long, + private var distanceMeasure: String) extends Logging { import BisectingKMeans._ @@ -65,7 +66,7 @@ class BisectingKMeans private ( * Constructs with the default configuration */ @Since("1.6.0") - def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##) + def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##, DistanceMeasure.EUCLIDEAN) /** * Sets the desired number of leaf clusters (default: 4). @@ -134,6 +135,22 @@ class BisectingKMeans private ( @Since("1.6.0") def getSeed: Long = this.seed + /** + * The distance suite used by the algorithm. + */ + @Since("2.4.0") + def getDistanceMeasure: String = distanceMeasure + + /** + * Set the distance suite used by the algorithm. + */ + @Since("2.4.0") + def setDistanceMeasure(distanceMeasure: String): this.type = { + DistanceMeasure.validateDistanceMeasure(distanceMeasure) + this.distanceMeasure = distanceMeasure + this + } + /** * Runs the bisecting k-means algorithm. * @param input RDD of vectors @@ -147,11 +164,13 @@ class BisectingKMeans private ( } val d = input.map(_.size).first() logInfo(s"Feature dimension: $d.") + + val dMeasure: DistanceMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure) // Compute and cache vector norms for fast distance computation. val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK) val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } var assignments = vectors.map(v => (ROOT_INDEX, v)) - var activeClusters = summarize(d, assignments) + var activeClusters = summarize(d, assignments, dMeasure) val rootSummary = activeClusters(ROOT_INDEX) val n = rootSummary.size logInfo(s"Number of points: $n.") @@ -184,24 +203,25 @@ class BisectingKMeans private ( val divisibleIndices = divisibleClusters.keys.toSet logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.") var newClusterCenters = divisibleClusters.flatMap { case (index, summary) => - val (left, right) = splitCenter(summary.center, random) + val (left, right) = splitCenter(summary.center, random, dMeasure) Iterator((leftChildIndex(index), left), (rightChildIndex(index), right)) }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map var newClusters: Map[Long, ClusterSummary] = null var newAssignments: RDD[(Long, VectorWithNorm)] = null for (iter <- 0 until maxIterations) { - newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters) + newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters, + dMeasure) .filter { case (index, _) => divisibleIndices.contains(parentIndex(index)) } - newClusters = summarize(d, newAssignments) + newClusters = summarize(d, newAssignments, dMeasure) newClusterCenters = newClusters.mapValues(_.center).map(identity) } if (preIndices != null) { preIndices.unpersist(false) } preIndices = indices - indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys + indices = updateAssignments(assignments, divisibleIndices, newClusterCenters, dMeasure).keys .persist(StorageLevel.MEMORY_AND_DISK) assignments = indices.zip(vectors) inactiveClusters ++= activeClusters @@ -222,8 +242,8 @@ class BisectingKMeans private ( } norms.unpersist(false) val clusters = activeClusters ++ inactiveClusters - val root = buildTree(clusters) - new BisectingKMeansModel(root) + val root = 
buildTree(clusters, dMeasure) + new BisectingKMeansModel(root, this.distanceMeasure) } /** @@ -266,8 +286,9 @@ private object BisectingKMeans extends Serializable { */ private def summarize( d: Int, - assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = { - assignments.aggregateByKey(new ClusterSummaryAggregator(d))( + assignments: RDD[(Long, VectorWithNorm)], + distanceMeasure: DistanceMeasure): Map[Long, ClusterSummary] = { + assignments.aggregateByKey(new ClusterSummaryAggregator(d, distanceMeasure))( seqOp = (agg, v) => agg.add(v), combOp = (agg1, agg2) => agg1.merge(agg2) ).mapValues(_.summary) @@ -278,7 +299,8 @@ private object BisectingKMeans extends Serializable { * Cluster summary aggregator. * @param d feature dimension */ - private class ClusterSummaryAggregator(val d: Int) extends Serializable { + private class ClusterSummaryAggregator(val d: Int, val distanceMeasure: DistanceMeasure) + extends Serializable { private var n: Long = 0L private val sum: Vector = Vectors.zeros(d) private var sumSq: Double = 0.0 @@ -288,7 +310,7 @@ private object BisectingKMeans extends Serializable { n += 1L // TODO: use a numerically stable approach to estimate cost sumSq += v.norm * v.norm - BLAS.axpy(1.0, v.vector, sum) + distanceMeasure.updateClusterSum(v, sum) this } @@ -296,19 +318,15 @@ private object BisectingKMeans extends Serializable { def merge(other: ClusterSummaryAggregator): this.type = { n += other.n sumSq += other.sumSq - BLAS.axpy(1.0, other.sum, sum) + distanceMeasure.updateClusterSum(new VectorWithNorm(other.sum), sum) this } /** Returns the summary. */ def summary: ClusterSummary = { - val mean = sum.copy - if (n > 0L) { - BLAS.scal(1.0 / n, mean) - } - val center = new VectorWithNorm(mean) - val cost = math.max(sumSq - n * center.norm * center.norm, 0.0) - new ClusterSummary(n, center, cost) + val center = distanceMeasure.centroid(sum.copy, n) + val cost = distanceMeasure.clusterCost(center, new VectorWithNorm(sum), n, sumSq) + ClusterSummary(n, center, cost) } } @@ -321,16 +339,13 @@ private object BisectingKMeans extends Serializable { */ private def splitCenter( center: VectorWithNorm, - random: Random): (VectorWithNorm, VectorWithNorm) = { + random: Random, + distanceMeasure: DistanceMeasure): (VectorWithNorm, VectorWithNorm) = { val d = center.vector.size val norm = center.norm val level = 1e-4 * norm val noise = Vectors.dense(Array.fill(d)(random.nextDouble())) - val left = center.vector.copy - BLAS.axpy(-level, noise, left) - val right = center.vector.copy - BLAS.axpy(level, noise, right) - (new VectorWithNorm(left), new VectorWithNorm(right)) + distanceMeasure.symmetricCentroids(level, noise, center.vector) } /** @@ -343,16 +358,20 @@ private object BisectingKMeans extends Serializable { private def updateAssignments( assignments: RDD[(Long, VectorWithNorm)], divisibleIndices: Set[Long], - newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = { + newClusterCenters: Map[Long, VectorWithNorm], + distanceMeasure: DistanceMeasure): RDD[(Long, VectorWithNorm)] = { assignments.map { case (index, v) => if (divisibleIndices.contains(index)) { val children = Seq(leftChildIndex(index), rightChildIndex(index)) - val newClusterChildren = children.filter(newClusterCenters.contains(_)) + val newClusterChildren = children.filter(newClusterCenters.contains) + val newClusterChildrenCenterToId = + newClusterChildren.map(id => newClusterCenters(id) -> id).toMap + val newClusterChildrenCenters = newClusterChildrenCenterToId.keys.toArray if 
(newClusterChildren.nonEmpty) { - val selected = newClusterChildren.minBy { child => - KMeans.fastSquaredDistance(newClusterCenters(child), v) - } - (selected, v) + val selected = distanceMeasure.findClosest(newClusterChildrenCenters, v)._1 + val center = newClusterChildrenCenters(selected) + val id = newClusterChildrenCenterToId(center) + (id, v) } else { (index, v) } @@ -367,7 +386,9 @@ private object BisectingKMeans extends Serializable { * @param clusters a map from cluster indices to corresponding cluster summaries * @return the root node of the clustering tree */ - private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = { + private def buildTree( + clusters: Map[Long, ClusterSummary], + distanceMeasure: DistanceMeasure): ClusteringTreeNode = { var leafIndex = 0 var internalIndex = -1 @@ -385,11 +406,11 @@ private object BisectingKMeans extends Serializable { internalIndex -= 1 val leftIndex = leftChildIndex(rawIndex) val rightIndex = rightChildIndex(rawIndex) - val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_)) - val height = math.sqrt(indexes.map { childIndex => - KMeans.fastSquaredDistance(center, clusters(childIndex).center) - }.max) - val children = indexes.map(buildSubTree(_)).toArray + val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains) + val height = indexes.map { childIndex => + distanceMeasure.distance(center, clusters(childIndex).center) + }.max + val children = indexes.map(buildSubTree).toArray new ClusteringTreeNode(index, size, center, cost, height, children) } else { val index = leafIndex @@ -441,41 +462,45 @@ private[clustering] class ClusteringTreeNode private[clustering] ( def center: Vector = centerWithNorm.vector /** Predicts the leaf cluster node index that the input point belongs to. */ - def predict(point: Vector): Int = { - val (index, _) = predict(new VectorWithNorm(point)) + def predict(point: Vector, distanceMeasure: DistanceMeasure): Int = { + val (index, _) = predict(new VectorWithNorm(point), distanceMeasure) index } /** Returns the full prediction path from root to leaf. */ - def predictPath(point: Vector): Array[ClusteringTreeNode] = { - predictPath(new VectorWithNorm(point)).toArray + def predictPath(point: Vector, distanceMeasure: DistanceMeasure): Array[ClusteringTreeNode] = { + predictPath(new VectorWithNorm(point), distanceMeasure).toArray } /** Returns the full prediction path from root to leaf. */ - private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = { + private def predictPath( + pointWithNorm: VectorWithNorm, + distanceMeasure: DistanceMeasure): List[ClusteringTreeNode] = { if (isLeaf) { this :: Nil } else { val selected = children.minBy { child => - KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + distanceMeasure.distance(child.centerWithNorm, pointWithNorm) } - selected :: selected.predictPath(pointWithNorm) + selected :: selected.predictPath(pointWithNorm, distanceMeasure) } } /** - * Computes the cost (squared distance to the predicted leaf cluster center) of the input point. + * Computes the cost of the input point. */ - def computeCost(point: Vector): Double = { - val (_, cost) = predict(new VectorWithNorm(point)) + def computeCost(point: Vector, distanceMeasure: DistanceMeasure): Double = { + val (_, cost) = predict(new VectorWithNorm(point), distanceMeasure) cost } /** * Predicts the cluster index and the cost of the input point. 
*/ - private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { - predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm)) + private def predict( + pointWithNorm: VectorWithNorm, + distanceMeasure: DistanceMeasure): (Int, Double) = { + predict(pointWithNorm, distanceMeasure.cost(centerWithNorm, pointWithNorm), distanceMeasure) } /** @@ -485,14 +510,17 @@ private[clustering] class ClusteringTreeNode private[clustering] ( * @return (predicted leaf cluster index, cost) */ @tailrec - private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = { + private def predict( + pointWithNorm: VectorWithNorm, + cost: Double, + distanceMeasure: DistanceMeasure): (Int, Double) = { if (isLeaf) { (index, cost) } else { val (selectedChild, minCost) = children.map { child => - (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + (child, distanceMeasure.cost(child.centerWithNorm, pointWithNorm)) }.minBy(_._2) - selectedChild.predict(pointWithNorm, minCost) + selectedChild.predict(pointWithNorm, minCost, distanceMeasure) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 633bda6aac80..9d115afcea75 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -40,9 +40,16 @@ import org.apache.spark.sql.{Row, SparkSession} */ @Since("1.6.0") class BisectingKMeansModel private[clustering] ( - private[clustering] val root: ClusteringTreeNode + private[clustering] val root: ClusteringTreeNode, + @Since("2.4.0") val distanceMeasure: String ) extends Serializable with Saveable with Logging { + @Since("1.6.0") + def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN) + + private val distanceMeasureInstance: DistanceMeasure = + DistanceMeasure.decodeFromString(distanceMeasure) + /** * Leaf cluster centers. 
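End to end, the new setter threads a DistanceMeasure through training and into the returned model. A minimal sketch against the RDD-based API, assuming an active SparkContext sc and toy data:

{{{
import org.apache.spark.mllib.clustering.{BisectingKMeans, DistanceMeasure}
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.1), Vectors.dense(0.9, 0.2),
  Vectors.dense(0.1, 1.0), Vectors.dense(0.2, 0.9)))

val model = new BisectingKMeans()
  .setK(2)
  .setDistanceMeasure(DistanceMeasure.COSINE)   // new in 2.4; the default stays EUCLIDEAN
  .run(data)

model.predict(Vectors.dense(1.0, 0.0))   // assigned using the cosine measure
model.distanceMeasure                    // "cosine", carried into save/load as well
}}}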
*/ @@ -59,7 +66,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def predict(point: Vector): Int = { - root.predict(point) + root.predict(point, distanceMeasureInstance) } /** @@ -67,7 +74,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def predict(points: RDD[Vector]): RDD[Int] = { - points.map { p => root.predict(p) } + points.map { p => root.predict(p, distanceMeasureInstance) } } /** @@ -82,7 +89,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(point: Vector): Double = { - root.computeCost(point) + root.computeCost(point, distanceMeasureInstance) } /** @@ -91,7 +98,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(data: RDD[Vector]): Double = { - data.map(root.computeCost).sum() + data.map(root.computeCost(_, distanceMeasureInstance)).sum() } /** @@ -113,18 +120,19 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { @Since("2.0.0") override def load(sc: SparkContext, path: String): BisectingKMeansModel = { - val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path) - implicit val formats = DefaultFormats - val rootId = (metadata \ "rootId").extract[Int] - val classNameV1_0 = SaveLoadV1_0.thisClassName + val (loadedClassName, formatVersion, __) = Loader.loadMetadata(sc, path) (loadedClassName, formatVersion) match { - case (classNameV1_0, "1.0") => - val model = SaveLoadV1_0.load(sc, path, rootId) + case (SaveLoadV1_0.thisClassName, SaveLoadV1_0.thisFormatVersion) => + val model = SaveLoadV1_0.load(sc, path) + model + case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) => + val model = SaveLoadV1_0.load(sc, path) model case _ => throw new Exception( s"BisectingKMeansModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $formatVersion). 
Supported:\n" + - s" ($classNameV1_0, 1.0)") + s" (${SaveLoadV1_0.thisClassName}, ${SaveLoadV1_0.thisClassName}\n" + + s" (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisClassName})") } } @@ -136,8 +144,28 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { r.getDouble(4), r.getDouble(5), r.getSeq[Int](6)) } + private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { + if (node.children.isEmpty) { + Array(node) + } else { + node.children.flatMap(getNodes) ++ Array(node) + } + } + + private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { + val root = nodes(rootId) + if (root.children.isEmpty) { + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, new Array[ClusteringTreeNode](0)) + } else { + val children = root.children.map(c => buildTree(c, nodes)) + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, children.toArray) + } + } + private[clustering] object SaveLoadV1_0 { - private val thisFormatVersion = "1.0" + private[clustering] val thisFormatVersion = "1.0" private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" @@ -155,34 +183,55 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) } - private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { - if (node.children.isEmpty) { - Array(node) - } else { - node.children.flatMap(getNodes(_)) ++ Array(node) - } - } - - def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = { + def load(sc: SparkContext, path: String): BisectingKMeansModel = { + implicit val formats: DefaultFormats = DefaultFormats + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rootId = (metadata \ "rootId").extract[Int] val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val rows = spark.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Data](rows.schema) val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap val rootNode = buildTree(rootId, nodes) - new BisectingKMeansModel(rootNode) + new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN) } + } + + private[clustering] object SaveLoadV2_0 { + private[clustering] val thisFormatVersion = "2.0" - private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { - val root = nodes.get(rootId).get - if (root.children.isEmpty) { - new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), - root.cost, root.height, new Array[ClusteringTreeNode](0)) - } else { - val children = root.children.map(c => buildTree(c, nodes)) - new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), - root.cost, root.height, children.toArray) - } + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" + + def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("rootId" -> model.root.index) ~ ("distanceMeasure" -> model.distanceMeasure))) + 
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val data = getNodes(model.root).map(node => Data(node.index, node.size, + node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, + node.children.map(_.index))) + spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): BisectingKMeansModel = { + implicit val formats: DefaultFormats = DefaultFormats + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rootId = (metadata \ "rootId").extract[Int] + val distanceMeasure = (metadata \ "distanceMeasure").extract[String] + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val rows = spark.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Data](rows.schema) + val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") + val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap + val rootNode = buildTree(rootId, nodes) + new BisectingKMeansModel(rootNode, distanceMeasure) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala new file mode 100644 index 000000000000..683360efabc7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} +import org.apache.spark.mllib.util.MLUtils + +private[spark] abstract class DistanceMeasure extends Serializable { + + /** + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + val currentDistance = distance(center, point) + if (currentDistance < bestDistance) { + bestDistance = currentDistance + bestIndex = i + } + i += 1 + } + (bestIndex, bestDistance) + } + + /** + * @return the K-means cost of a given point against the given cluster centers. + */ + def pointCost( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): Double = { + findClosest(centers, point)._2 + } + + /** + * @return whether a center converged or not, given the epsilon parameter. 
+ */ + def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + distance(oldCenter, newCenter) <= epsilon + } + + /** + * @return the distance between two points. + */ + def distance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double + + /** + * @return the total cost of the cluster from its aggregated properties + */ + def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + new VectorWithNorm(sum) + } + + /** + * Returns two new centroids symmetric to the specified centroid applying `noise` with the + * with the specified `level`. + * + * @param level the level of `noise` to apply to the given centroid. + * @param noise a noise vector + * @param centroid the parent centroid + * @return a left and right centroid symmetric to `centroid` + */ + def symmetricCentroids( + level: Double, + noise: Vector, + centroid: Vector): (VectorWithNorm, VectorWithNorm) = { + val left = centroid.copy + axpy(-level, noise, left) + val right = centroid.copy + axpy(level, noise, right) + (new VectorWithNorm(left), new VectorWithNorm(right)) + } + + /** + * @return the cost of a point to be assigned to the cluster centroid + */ + def cost( + point: VectorWithNorm, + centroid: VectorWithNorm): Double = distance(point, centroid) +} + +@Since("2.4.0") +object DistanceMeasure { + + @Since("2.4.0") + val EUCLIDEAN = "euclidean" + @Since("2.4.0") + val COSINE = "cosine" + + private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = + distanceMeasure match { + case EUCLIDEAN => new EuclideanDistanceMeasure + case COSINE => new CosineDistanceMeasure + case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + + s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") + } + + private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { + distanceMeasure match { + case DistanceMeasure.EUCLIDEAN => true + case DistanceMeasure.COSINE => true + case _ => false + } + } +} + +private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { + /** + * @return the index of the closest center to the given point, as well as the squared distance. + */ + override def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary + // distance computation. 
+ var lowerBoundOfSqDist = center.norm - point.norm + lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist + if (lowerBoundOfSqDist < bestDistance) { + val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) + if (distance < bestDistance) { + bestDistance = distance + bestIndex = i + } + } + i += 1 + } + (bestIndex, bestDistance) + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + override def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon + } + + /** + * @param v1: first vector + * @param v2: second vector + * @return the Euclidean distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) + } + + /** + * @return the total cost of the cluster from its aggregated properties + */ + override def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double = { + math.max(pointsSquaredNorm - numberOfPoints * centroid.norm * centroid.norm, 0.0) + } + + /** + * @return the cost of a point to be assigned to the cluster centroid + */ + override def cost( + point: VectorWithNorm, + centroid: VectorWithNorm): Double = { + EuclideanDistanceMeasure.fastSquaredDistance(point, centroid) + } +} + + +private[spark] object EuclideanDistanceMeasure { + /** + * @return the squared Euclidean distance between two vectors computed by + * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. + */ + private[clustering] def fastSquaredDistance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double = { + MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) + } +} + +private[spark] class CosineDistanceMeasure extends DistanceMeasure { + /** + * @param v1: first vector + * @param v2: second vector + * @return the cosine distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") + 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm + } + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.") + axpy(1.0 / point.norm, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. 
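The cosine measure defined here is one minus the cosine similarity, and cluster sums accumulate length-normalized points so that the centroid remains well defined. A small numeric illustration of the distance itself:

{{{
import org.apache.spark.mllib.linalg.Vectors

val a = Vectors.dense(1.0, 0.0)
val b = Vectors.dense(1.0, 1.0)

// cosine distance = 1 - dot(a, b) / (||a|| * ||b||)
val dotAB = a(0) * b(0) + a(1) * b(1)                               // 1.0
val dist  = 1.0 - dotAB / (Vectors.norm(a, 2.0) * Vectors.norm(b, 2.0))
// ~0.2929: 0 for parallel vectors, 1 for orthogonal ones, undefined for zero vectors,
// which is why the assertions above reject zero-length input.
}}}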
+ * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + override def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + val norm = Vectors.norm(sum, 2) + scal(1.0 / norm, sum) + new VectorWithNorm(sum, 1) + } + + /** + * @return the total cost of the cluster from its aggregated properties + */ + override def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double = { + val costVector = pointsSum.vector.copy + math.max(numberOfPoints - dot(centroid.vector, costVector) / centroid.norm, 0.0) + } + + /** + * Returns two new centroids symmetric to the specified centroid applying `noise` with the + * with the specified `level`. + * + * @param level the level of `noise` to apply to the given centroid. + * @param noise a noise vector + * @param centroid the parent centroid + * @return a left and right centroid symmetric to `centroid` + */ + override def symmetricCentroids( + level: Double, + noise: Vector, + centroid: Vector): (VectorWithNorm, VectorWithNorm) = { + val (left, right) = super.symmetricCentroids(level, noise, centroid) + val leftVector = left.vector + val rightVector = right.vector + scal(1.0 / left.norm, leftVector) + scal(1.0 / right.norm, rightVector) + (new VectorWithNorm(leftVector, 1.0), new VectorWithNorm(rightVector, 1.0)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 49043b5acb80..b5b1be349049 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -25,8 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.clustering.{KMeans => NewKMeans} import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.linalg.BLAS.axpy import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -46,14 +45,23 @@ class KMeans private ( private var initializationMode: String, private var initializationSteps: Int, private var epsilon: Double, - private var seed: Long) extends Serializable with Logging { + private var seed: Long, + private var distanceMeasure: String) extends Serializable with Logging { + + @Since("0.8.0") + private def this(k: Int, maxIterations: Int, initializationMode: String, initializationSteps: Int, + epsilon: Double, seed: Long) = + this(k, maxIterations, initializationMode, initializationSteps, + epsilon, seed, DistanceMeasure.EUCLIDEAN) /** * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, - * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random}. + * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random, + * distanceMeasure: "euclidean"}. */ @Since("0.8.0") - def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) + def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong(), + DistanceMeasure.EUCLIDEAN) /** * Number of clusters to create (k). @@ -184,6 +192,22 @@ class KMeans private ( this } + /** + * The distance suite used by the algorithm. 
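The same option is exposed on the RDD-based KMeans, as the following hunks show. A short sketch mirroring the bisecting example earlier, where data is any RDD[Vector]:

{{{
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans}

val kmeans = new KMeans()
  .setK(3)
  .setMaxIterations(10)
  .setDistanceMeasure(DistanceMeasure.COSINE)

// kmeans.run(data) returns a KMeansModel that remembers the measure and
// uses it for predict() and computeCost().
}}}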
+ */ + @Since("2.4.0") + def getDistanceMeasure: String = distanceMeasure + + /** + * Set the distance suite used by the algorithm. + */ + @Since("2.4.0") + def setDistanceMeasure(distanceMeasure: String): this.type = { + DistanceMeasure.validateDistanceMeasure(distanceMeasure) + this.distanceMeasure = distanceMeasure + this + } + // Initial cluster centers can be provided as a KMeansModel object rather than using the // random or k-means|| initializationMode private var initialModel: Option[KMeansModel] = None @@ -246,6 +270,8 @@ class KMeans private ( val initStartTime = System.nanoTime() + val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure) + val centers = initialModel match { case Some(kMeansCenters) => kMeansCenters.clusterCenters.map(new VectorWithNorm(_)) @@ -253,7 +279,7 @@ class KMeans private ( if (initializationMode == KMeans.RANDOM) { initRandom(data) } else { - initKMeansParallel(data) + initKMeansParallel(data, distanceMeasureInstance) } } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 @@ -281,10 +307,9 @@ class KMeans private ( val counts = Array.fill(thisCenters.length)(0L) points.foreach { point => - val (bestCenter, cost) = KMeans.findClosest(thisCenters, point) + val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point) costAccum.add(cost) - val sum = sums(bestCenter) - axpy(1.0, point.vector, sum) + distanceMeasureInstance.updateClusterSum(point, sums(bestCenter)) counts(bestCenter) += 1 } @@ -292,17 +317,17 @@ class KMeans private ( }.reduceByKey { case ((sum1, count1), (sum2, count2)) => axpy(1.0, sum2, sum1) (sum1, count1 + count2) - }.mapValues { case (sum, count) => - scal(1.0 / count, sum) - new VectorWithNorm(sum) - }.collectAsMap() + }.collectAsMap().mapValues { case (sum, count) => + distanceMeasureInstance.centroid(sum, count) + } bcCenters.destroy(blocking = false) // Update the cluster centers and costs converged = true newCenters.foreach { case (j, newCenter) => - if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) { + if (converged && + !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) { converged = false } centers(j) = newCenter @@ -323,7 +348,7 @@ class KMeans private ( logInfo(s"The cost is $cost.") - new KMeansModel(centers.map(_.vector)) + new KMeansModel(centers.map(_.vector), distanceMeasure) } /** @@ -345,7 +370,8 @@ class KMeans private ( * * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. */ - private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { + private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm], + distanceMeasureInstance: DistanceMeasure): Array[VectorWithNorm] = { // Initialize empty centers and point costs. 
var costs = data.map(_ => Double.PositiveInfinity) @@ -369,7 +395,7 @@ class KMeans private ( bcNewCentersList += bcNewCenters val preCosts = costs costs = data.zip(preCosts).map { case (point, cost) => - math.min(KMeans.pointCost(bcNewCenters.value, point), cost) + math.min(distanceMeasureInstance.pointCost(bcNewCenters.value, point), cost) }.persist(StorageLevel.MEMORY_AND_DISK) val sumCosts = costs.sum() @@ -397,7 +423,9 @@ class KMeans private ( // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick k of them val bcCenters = data.context.broadcast(distinctCenters) - val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue() + val countMap = data + .map(distanceMeasureInstance.findClosest(bcCenters.value, _)._1) + .countByValue() bcCenters.destroy(blocking = false) @@ -546,50 +574,6 @@ object KMeans { .run(data) } - /** - * Returns the index of the closest center to the given point, as well as the squared distance. - */ - private[mllib] def findClosest( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): (Int, Double) = { - var bestDistance = Double.PositiveInfinity - var bestIndex = 0 - var i = 0 - centers.foreach { center => - // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary - // distance computation. - var lowerBoundOfSqDist = center.norm - point.norm - lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist - if (lowerBoundOfSqDist < bestDistance) { - val distance: Double = fastSquaredDistance(center, point) - if (distance < bestDistance) { - bestDistance = distance - bestIndex = i - } - } - i += 1 - } - (bestIndex, bestDistance) - } - - /** - * Returns the K-means cost of a given point against the given cluster centers. - */ - private[mllib] def pointCost( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): Double = - findClosest(centers, point)._2 - - /** - * Returns the squared Euclidean distance between two vectors computed by - * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. - */ - private[clustering] def fastSquaredDistance( - v1: VectorWithNorm, - v2: VectorWithNorm): Double = { - MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) - } - private[spark] def validateInitMode(initMode: String): Boolean = { initMode match { case KMeans.RANDOM => true @@ -601,11 +585,9 @@ object KMeans { /** * A vector with its norm for fast distance computation. - * - * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]] */ -private[clustering] -class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable { +private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double) + extends Serializable { def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 3ad08c46d204..a78c21e838e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -36,12 +36,20 @@ import org.apache.spark.sql.{Row, SparkSession} * A clustering model for K-means. Each point belongs to the cluster with the closest center. 
*/ @Since("0.8.0") -class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector]) +class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector], + @Since("2.4.0") val distanceMeasure: String) extends Saveable with Serializable with PMMLExportable { + private val distanceMeasureInstance: DistanceMeasure = + DistanceMeasure.decodeFromString(distanceMeasure) + private val clusterCentersWithNorm = if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + @Since("1.1.0") + def this(clusterCenters: Array[Vector]) = + this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN) + /** * A Java-friendly constructor that takes an Iterable of Vectors. */ @@ -59,7 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec */ @Since("0.8.0") def predict(point: Vector): Int = { - KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 + distanceMeasureInstance.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 } /** @@ -68,7 +76,8 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("1.0.0") def predict(points: RDD[Vector]): RDD[Int] = { val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm) - points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) + points.map(p => + distanceMeasureInstance.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) } /** @@ -85,8 +94,9 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("0.8.0") def computeCost(data: RDD[Vector]): Double = { val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm) - val cost = data - .map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum() + val cost = data.map(p => + distanceMeasureInstance.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))) + .sum() bcCentersWithNorm.destroy(blocking = false) cost } @@ -94,7 +104,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { - KMeansModel.SaveLoadV1_0.save(sc, this, path) + KMeansModel.SaveLoadV2_0.save(sc, this, path) } override protected def formatVersion: String = "1.0" @@ -105,7 +115,20 @@ object KMeansModel extends Loader[KMeansModel] { @Since("1.4.0") override def load(sc: SparkContext, path: String): KMeansModel = { - KMeansModel.SaveLoadV1_0.load(sc, path) + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + val classNameV2_0 = SaveLoadV2_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case (className, "2.0") if className == classNameV2_0 => + SaveLoadV2_0.load(sc, path) + case _ => throw new Exception( + s"KMeansModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). 
Supported:\n" + + s" ($classNameV1_0, 1.0)\n" + + s" ($classNameV2_0, 2.0)") + } } private case class Cluster(id: Int, point: Vector) @@ -116,8 +139,7 @@ object KMeansModel extends Loader[KMeansModel] { } } - private[clustering] - object SaveLoadV1_0 { + private[clustering] object SaveLoadV1_0 { private val thisFormatVersion = "1.0" @@ -149,4 +171,38 @@ object KMeansModel extends Loader[KMeansModel] { new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) } } + + private[clustering] object SaveLoadV2_0 { + + private val thisFormatVersion = "2.0" + + private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" + + def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) => + Cluster(id, p.vector) + } + spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): KMeansModel = { + implicit val formats = DefaultFormats + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val k = (metadata \ "k").extract[Int] + val centroids = spark.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Cluster](centroids.schema) + val localCentroids = centroids.rdd.map(Cluster.apply).collect() + assert(k == localCentroids.length) + val distanceMeasure = (metadata \ "distanceMeasure").extract[String] + new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index 53587670a5db..4a08c0a55e68 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -46,7 +46,7 @@ private[mllib] object LocalKMeans extends Logging { // Initialize centers by sampling using the k-means++ procedure. 
centers(0) = pickWeighted(rand, points, weights).toDense - val costArray = points.map(KMeans.fastSquaredDistance(_, centers(0))) + val costArray = points.map(EuclideanDistanceMeasure.fastSquaredDistance(_, centers(0))) for (i <- 1 until k) { val sum = costArray.zip(weights).map(p => p._1 * p._2).sum @@ -67,11 +67,15 @@ private[mllib] object LocalKMeans extends Logging { // update costArray for (p <- points.indices) { - costArray(p) = math.min(KMeans.fastSquaredDistance(points(p), centers(i)), costArray(p)) + costArray(p) = math.min( + EuclideanDistanceMeasure.fastSquaredDistance(points(p), centers(i)), + costArray(p)) } } + val distanceMeasureInstance = new EuclideanDistanceMeasure + // Run up to maxIterations iterations of Lloyd's algorithm val oldClosest = Array.fill(points.length)(-1) var iteration = 0 @@ -83,7 +87,7 @@ private[mllib] object LocalKMeans extends Logging { var i = 0 while (i < points.length) { val p = points(i) - val index = KMeans.findClosest(centers, p)._1 + val index = distanceMeasureInstance.findClosest(centers, p)._1 axpy(weights(i), p.vector, sums(index)) counts(index) += weights(i) if (index != oldClosest(i)) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 9abdd44a635d..7b73b286fb91 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -135,7 +135,7 @@ object HashingTF { private[HashingTF] val Murmur3: String = "murmur3" - private val seed = 42 + private[spark] val seed = 42 /** * Calculate a hash code value for the term object using the native Scala implementation. @@ -160,7 +160,7 @@ object HashingTF { case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => val utf8 = UTF8String.fromString(s) - hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + hashUnsafeBytesBlock(utf8.getMemoryBlock(), seed) case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index f6b1143272d1..4f2b7e6f0764 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -162,7 +162,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { * */ @Since("1.3.0") -class FPGrowth private ( +class FPGrowth private[spark] ( private var minSupport: Double, private var numPartitions: Int) extends Logging with Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 3f8d65a378e2..7aed2f3bd8a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -49,8 +49,7 @@ import org.apache.spark.storage.StorageLevel * * @param minSupport the minimal support level of the sequential pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output - * @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears - * less than maxPatternLength will be output + * @param maxPatternLength the maximal length of the sequential pattern * @param 
maxLocalProjDBSize The maximum number of items (including delimiters used in the internal * storage format) allowed in a projected database before local * processing. If a projected database exceeds this size, another diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index fd9605c01362..6e68d9684a67 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -326,8 +326,6 @@ object Vectors { */ @Since("1.0.0") def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0, "The size of the requested sparse vector must be greater than 0.") - val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 indices.foreach { i => @@ -758,6 +756,7 @@ class SparseVector @Since("1.0.0") ( @Since("1.0.0") val indices: Array[Int], @Since("1.0.0") val values: Array[Double]) extends Vector { + require(size >= 0, "The size of the requested sparse vector must be no less than 0.") require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. You provided ${indices.length} indices and " + s" ${values.length} values.") diff --git a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java index 43779878890d..35a250955b28 100644 --- a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java +++ b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java @@ -42,7 +42,12 @@ public void setUp() throws IOException { @After public void tearDown() { - spark.stop(); - spark = null; + try { + spark.stop(); + spark = null; + } finally { + SparkSession.clearDefaultSession(); + SparkSession.clearActiveSession(); + } } } diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java new file mode 100644 index 000000000000..830f668fe07b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.stat; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.spark.sql.Encoders; +import org.junit.Test; + +import org.apache.spark.SharedSparkSession; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; + + +public class JavaKolmogorovSmirnovTestSuite extends SharedSparkSession { + + private transient Dataset dataset; + + @Override + public void setUp() throws IOException { + super.setUp(); + List points = Arrays.asList(0.1, 1.1, 10.1, -1.1); + + dataset = spark.createDataset(points, Encoders.DOUBLE()).toDF("sample"); + } + + @Test + public void testKSTestCDF() { + // Create theoretical distributions + NormalDistribution stdNormalDist = new NormalDistribution(0, 1); + + // set seeds + Long seed = 10L; + stdNormalDist.reseedRandomGenerator(seed); + Function stdNormalCDF = (x) -> stdNormalDist.cumulativeProbability(x); + + double pThreshold = 0.05; + + // Comparing a standard normal sample to a standard normal distribution + Row results = KolmogorovSmirnovTest + .test(dataset, "sample", stdNormalCDF).head(); + double pValue1 = results.getDouble(0); + // Cannot reject null hypothesis + assert(pValue1 > pThreshold); + } + + @Test + public void testKSTestNamedDistribution() { + double pThreshold = 0.05; + + // Comparing a standard normal sample to a standard normal distribution + Row results = KolmogorovSmirnovTest + .test(dataset, "sample", "norm", 0.0, 1.0).head(); + double pValue1 = results.getDouble(0); + // Cannot reject null hypothesis + assert(pValue1 > pThreshold); + } +} diff --git a/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 000000000000..100ef2545418 --- /dev/null +++ b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1,3 @@ +org.apache.spark.ml.util.DuplicateLinearRegressionWriter1 +org.apache.spark.ml.util.DuplicateLinearRegressionWriter2 +org.apache.spark.ml.util.FakeLinearRegressionWriterWithName diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index de712079329d..87bf2be06c2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -90,6 +90,13 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } assert(e.getMessage.contains("requires integers in range")) } + val df3 = getTestData(Seq.empty[Double]) + withClue("getNumClasses should fail if dataset is empty") { + val e: SparkException = intercept[SparkException] { + c.getNumClasses(df3) + } + assert(e.getMessage == "ML algorithm was given empty dataset.") + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 98c879ece62d..d3dbb4e754d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,17 +21,16 @@ import org.apache.spark.SparkFunSuite import 
org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} +import org.apache.spark.ml.tree.ClassificationLeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -class DecisionTreeClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { import DecisionTreeClassifierSuite.compareAPIs import testImplicits._ @@ -62,7 +61,8 @@ class DecisionTreeClassifierSuite test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2) + val model = new DecisionTreeClassificationModel("dtc", + new ClassificationLeafNode(0.0, 0.0, null), 1, 2) ParamsSuite.checkParams(model) } @@ -251,20 +251,33 @@ class DecisionTreeClassifierSuite MLTestingUtils.checkCopyAndUids(dt, newTree) - val predictions = newTree.transform(newData) - .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) - .collect() - - predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => - assert(pred === rawPred.argmax, - s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") - val sum = rawPred.toArray.sum - assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, - "probability prediction mismatch") + testTransformer[(Vector, Double)](newData, newTree, + "prediction", "rawPrediction", "probability") { + case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, DecisionTreeClassificationModel](newTree, newData) + Vector, DecisionTreeClassificationModel](this, newTree, newData) + } + + test("prediction on single instance") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + testPredictionModelSinglePrediction(newTree, newData) } test("training with 1-category categorical feature") { @@ -280,44 +293,6 @@ class DecisionTreeClassifierSuite dt.fit(df) } - test("Use soft prediction for binary classification with ordered categorical features") { - // The following dataset is set up such that the best split is {1} vs. {0, 2}. - // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. 
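The transformer check migrated above encodes a simple invariant for tree classifiers: the probability vector is the raw (per-class count) vector rescaled to sum to one, and the prediction is its argmax. A tiny standalone illustration with made-up counts:

import org.apache.spark.ml.linalg.Vectors

val raw = Vectors.dense(1.0, 1.0, 2.0)             // illustrative class counts at the predicted leaf
val sum = raw.toArray.sum
val prob = Vectors.dense(raw.toArray.map(_ / sum)) // (0.25, 0.25, 0.5)
val pred = raw.argmax.toDouble                     // 2.0

assert(prob.toArray.sum == 1.0)
assert(pred == prob.argmax.toDouble)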
- val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(1.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(1.0, Vectors.dense(2.0))) - val data = sc.parallelize(arr) - val df = TreeTests.setMetadata(data, Map(0 -> 3), 2) - - // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. - val dt = new DecisionTreeClassifier() - .setImpurity("gini") - .setMaxDepth(1) - .setMaxBins(3) - val model = dt.fit(df) - model.rootNode match { - case n: InternalNode => - n.split match { - case s: CategoricalSplit => - assert(s.leftCategories === Array(1.0)) - case other => - fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.") - } - case other => - fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.") - } - } - test("Feature importance with toy data") { val dt = new DecisionTreeClassifier() .setImpurity("gini") @@ -401,6 +376,32 @@ class DecisionTreeClassifierSuite testDefaultReadWrite(model) } + + test("label/impurity stats") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2) + val dt1 = new DecisionTreeClassifier() + .setImpurity("entropy") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model1 = dt1.fit(df) + + val rootNode1 = model1.rootNode + assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === Array(2.0, 1.0)) + + val dt2 = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model2 = dt2.fit(df) + + val rootNode2 = model2.rootNode + assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === Array(2.0, 1.0)) + } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 978f89c459f0..f0ee5496f9d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -24,15 +24,14 @@ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.RegressionLeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.LogLoss -import 
org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.util.Utils @@ -40,8 +39,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import GBTClassifierSuite.compareAPIs @@ -71,7 +69,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)), + Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 0.0, null), 1)), Array(1.0), 1, 2) ParamsSuite.checkParams(model) } @@ -126,14 +124,15 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext // should predict all zeros binaryModel.setThresholds(Array(0.0, 1.0)) - val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect() - assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](df, binaryModel, "prediction") { + case Row(prediction: Double) => prediction === 0.0 + } // should predict all ones binaryModel.setThresholds(Array(1.0, 0.0)) - val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect() - assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) - + testTransformer[(Double, Vector)](df, binaryModel, "prediction") { + case Row(prediction: Double) => prediction === 1.0 + } val gbtBase = new GBTClassifier val model = gbtBase.fit(df) @@ -141,15 +140,18 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext // constant threshold scaling is the same as no thresholds binaryModel.setThresholds(Array(1.0, 1.0)) - val scaledPredictions = binaryModel.transform(df).select("prediction").collect() - assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => - scaled.getDouble(0) === base.getDouble(0) - }) + testTransformerByGlobalCheckFunc[(Double, Vector)](df, binaryModel, "prediction") { + scaledPredictions: Seq[Row] => + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + } // force it to use the predict method model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1)) - val predictionsWithPredict = model.transform(df).select("prediction").collect() - assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](df, model, "prediction") { + case Row(prediction: Double) => prediction === 0.0 + } } test("GBTClassifier: Predictor, Classifier methods") { @@ -169,61 +171,39 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext val blas = BLAS.getInstance() val validationDataset = validationData.toDF(labelCol, featuresCol) - val results = gbtModel.transform(validationDataset) - // check that raw prediction is tree predictions dot tree weights - results.select(rawPredictionCol, featuresCol).collect().foreach { - case Row(raw: Vector, features: Vector) => + testTransformer[(Double, Vector)](validationDataset, gbtModel, + "rawPrediction", "features", "probability", "prediction") { + case Row(raw: Vector, features: Vector, prob: Vector, pred: Double) => assert(raw.size === 2) + // check that raw prediction is 
tree predictions dot tree weights val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction) val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1) assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps) - } - // Compare rawPrediction with probability - results.select(rawPredictionCol, probabilityCol).collect().foreach { - case Row(raw: Vector, prob: Vector) => - assert(raw.size === 2) + // Compare rawPrediction with probability assert(prob.size === 2) // Note: we should check other loss types for classification if they are added val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value)) assert(prob(0) ~== predFromRaw(0) relTol eps) assert(prob(1) ~== predFromRaw(1) relTol eps) assert(prob(0) + prob(1) ~== 1.0 absTol absEps) - } - // Compare prediction with probability - results.select(predictionCol, probabilityCol).collect().foreach { - case Row(pred: Double, prob: Vector) => + // Compare prediction with probability val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) } - // force it to use raw2prediction - gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("") - val resultsUsingRaw2Predict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, GBTClassificationModel](this, gbtModel, validationDataset) + } - // force it to use probability2prediction - gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol) - val resultsUsingProb2Predict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } + test("prediction on single instance") { - // force it to use predict - gbtModel.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } + val gbt = new GBTClassifier().setSeed(123) + val trainingDataset = trainData.toDF("label", "features") + val gbtModel = gbt.fit(trainingDataset) - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, GBTClassificationModel](gbtModel, validationDataset) + testPredictionModelSinglePrediction(gbtModel, trainingDataset) } test("GBT parameter stepSize should be in interval (0, 1]") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 41a5d22dd628..c05c896df5cb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -21,20 +21,18 @@ import scala.util.Random import breeze.linalg.{DenseVector => BDV} -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.optim.aggregator.HingeAggregator import 
org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.udf -class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LinearSVCSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -141,10 +139,11 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau threshold: Double, expected: Set[(Int, Double)]): Unit = { model.setThreshold(threshold) - val results = model.transform(df).select("id", "prediction").collect() - .map(r => (r.getInt(0), r.getDouble(1))) - .toSet - assert(results === expected, s"Failed for threshold = $threshold") + testTransformerByGlobalCheckFunc[(Int, Vector)](df, model, "id", "prediction") { + rows: Seq[Row] => + val results = rows.map(r => (r.getInt(0), r.getDouble(1))).toSet + assert(results === expected, s"Failed for threshold = $threshold") + } } def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = { @@ -202,6 +201,12 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau dataset.as[LabeledPoint], estimator, modelEquals, 42L) } + test("prediction on single instance") { + val trainer = new LinearSVC() + val model = trainer.fit(smallBinaryDataset) + testPredictionModelSinglePrediction(model, smallBinaryDataset) + } + test("linearSVC comparison with R e1071 and scikit-learn") { val trainer1 = new LinearSVC() .setRegParam(0.00002) // set regParam = 2.0 / datasize / c diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a5f81a38face..36b7e51f93d0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -22,22 +22,20 @@ import scala.language.existentials import scala.util.Random import scala.util.control.Breaks._ -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors} import org.apache.spark.ml.optim.aggregator.LogisticAggregator import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, lit, rand} import org.apache.spark.sql.types.LongType -class LogisticRegressionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -332,15 +330,14 @@ class LogisticRegressionSuite val binaryModel = blr.fit(smallBinaryDataset) binaryModel.setThreshold(1.0) - val 
binaryZeroPredictions = - binaryModel.transform(smallBinaryDataset).select("prediction").collect() - assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } binaryModel.setThreshold(0.0) - val binaryOnePredictions = - binaryModel.transform(smallBinaryDataset).select("prediction").collect() - assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) - + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { + row => assert(row.getDouble(0) === 1.0) + } val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(smallMultinomialDataset) @@ -348,31 +345,36 @@ class LogisticRegressionSuite // should predict all zeros model.setThresholds(Array(1, 1000, 1000)) - val zeroPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(zeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } // should predict all ones model.setThresholds(Array(1000, 1, 1000)) - val onePredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(onePredictions.forall(_.getDouble(0) === 1.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 1.0) + } // should predict all twos model.setThresholds(Array(1000, 1000, 1)) - val twoPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(twoPredictions.forall(_.getDouble(0) === 2.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 2.0) + } // constant threshold scaling is the same as no thresholds model.setThresholds(Array(1000, 1000, 1000)) - val scaledPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => - scaled.getDouble(0) === base.getDouble(0) - }) + testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, + "prediction") { scaledPredictions: Seq[Row] => + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + } // force it to use the predict method model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1, 1)) - val predictionsWithPredict = - model.transform(smallMultinomialDataset).select("prediction").collect() - assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } } test("logistic regression doesn't fit intercept when fitIntercept is off") { @@ -403,21 +405,19 @@ class LogisticRegressionSuite // Modify model params, and check that the params worked. 
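The threshold tests above lean on how a ProbabilisticClassificationModel turns probabilities into a label once thresholds are set: to my reading it picks the class that maximizes probability(k) / threshold(k), which is why scaling every threshold by the same constant cannot change the argmax and why a zero threshold forces its class. A small sketch of that selection rule (the helper name is invented for illustration; the rule itself is the assumption here):

def predictWithThresholds(probability: Array[Double], thresholds: Array[Double]): Double = {
  var argMax = 0
  var max = Double.NegativeInfinity
  var i = 0
  while (i < probability.length) {
    // t == 0 yields Infinity for p > 0 (that class wins) and NaN for p == 0 (comparison fails, it loses).
    val scaled = probability(i) / thresholds(i)
    if (scaled > max) { max = scaled; argMax = i }
    i += 1
  }
  argMax.toDouble
}

val prob = Array(0.2, 0.3, 0.5)
assert(predictWithThresholds(prob, Array(1.0, 1.0, 1.0)) == 2.0)
assert(predictWithThresholds(prob, Array(1000.0, 1000.0, 1000.0)) == 2.0) // constant scaling is a no-op
assert(predictWithThresholds(prob, Array(0.0, 1.0, 1.0)) == 0.0)          // a zero threshold forces class 0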
model.setThreshold(1.0) - val predAllZero = model.transform(smallBinaryDataset) - .select("prediction", "myProbability") - .collect() - .map { case Row(pred: Double, prob: Vector) => pred } - assert(predAllZero.forall(_ === 0), - s"With threshold=1.0, expected predictions to be all 0, but only" + - s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.") + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), + model, "prediction", "myProbability") { rows => + val predAllZero = rows.map(_.getDouble(0)) + assert(predAllZero.forall(_ === 0), + s"With threshold=1.0, expected predictions to be all 0, but only" + + s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.") + } // Call transform with params, and check that the params worked. - val predNotAllZero = - model.transform(smallBinaryDataset, model.threshold -> 0.0, - model.probabilityCol -> "myProb") - .select("prediction", "myProb") - .collect() - .map { case Row(pred: Double, prob: Vector) => pred } - assert(predNotAllZero.exists(_ !== 0.0)) + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), + model.copy(ParamMap(model.threshold -> 0.0, + model.probabilityCol -> "myProb")), "prediction", "myProb") { + rows => assert(rows.map(_.getDouble(0)).exists(_ !== 0.0)) + } // Call fit() with new params, and check as many params as we can. lr.setThresholds(Array(0.6, 0.4)) @@ -441,10 +441,10 @@ class LogisticRegressionSuite val numFeatures = smallMultinomialDataset.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val results = model.transform(smallMultinomialDataset) - // check that raw prediction is coefficients dot features + intercept - results.select("rawPrediction", "features").collect().foreach { - case Row(raw: Vector, features: Vector) => + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), + model, "rawPrediction", "features", "probability") { + case Row(raw: Vector, features: Vector, prob: Vector) => + // check that raw prediction is coefficients dot features + intercept assert(raw.size === 3) val margins = Array.tabulate(3) { k => var margin = 0.0 @@ -455,12 +455,7 @@ class LogisticRegressionSuite margin } assert(raw ~== Vectors.dense(margins) relTol eps) - } - - // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().foreach { - case Row(raw: Vector, prob: Vector) => - assert(raw.size === 3) + // Compare rawPrediction with probability assert(prob.size === 3) val max = raw.toArray.max val subtract = if (max > 0) max else 0.0 @@ -472,39 +467,8 @@ class LogisticRegressionSuite assert(prob(2) ~== 1.0 - probFromRaw1 - probFromRaw0 relTol eps) } - // Compare prediction with probability - results.select("prediction", "probability").collect().foreach { - case Row(pred: Double, prob: Vector) => - val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 - assert(pred == predFromProb) - } - - // force it to use raw2prediction - model.setRawPredictionCol("rawPrediction").setProbabilityCol("") - val resultsUsingRaw2Predict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - model.setRawPredictionCol("").setProbabilityCol("probability") - val resultsUsingProb2Predict = - 
model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - model.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, LogisticRegressionModel](model, smallMultinomialDataset) + Vector, LogisticRegressionModel](this, model, smallMultinomialDataset) } test("binary logistic regression: Predictor, Classifier methods") { @@ -517,51 +481,31 @@ class LogisticRegressionSuite val numFeatures = smallBinaryDataset.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val results = model.transform(smallBinaryDataset) - - // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().foreach { - case Row(raw: Vector, prob: Vector) => + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), + model, "rawPrediction", "probability", "prediction") { + case Row(raw: Vector, prob: Vector, pred: Double) => + // Compare rawPrediction with probability assert(raw.size === 2) assert(prob.size === 2) val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1))) assert(prob(1) ~== probFromRaw1 relTol eps) assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps) - } - - // Compare prediction with probability - results.select("prediction", "probability").collect().foreach { - case Row(pred: Double, prob: Vector) => + // Compare prediction with probability val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) } - // force it to use raw2prediction - model.setRawPredictionCol("rawPrediction").setProbabilityCol("") - val resultsUsingRaw2Predict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - model.setRawPredictionCol("").setProbabilityCol("probability") - val resultsUsingProb2Predict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - model.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, LogisticRegressionModel](model, smallBinaryDataset) + Vector, LogisticRegressionModel](this, model, smallBinaryDataset) + } + + test("prediction on single instance") { + val blor = new LogisticRegression().setFamily("binomial") + val blorModel = blor.fit(smallBinaryDataset) + testPredictionModelSinglePrediction(blorModel, smallBinaryDataset) + val mlor = new LogisticRegression().setFamily("multinomial") + val mlorModel = mlor.fit(smallMultinomialDataset) + testPredictionModelSinglePrediction(mlorModel, 
smallMultinomialDataset) } test("coefficients and intercept methods") { @@ -616,19 +560,21 @@ class LogisticRegressionSuite LabeledPoint(1.0, Vectors.dense(0.0, 1000.0)), LabeledPoint(1.0, Vectors.dense(0.0, -1.0)) ).toDF() - val results = model.transform(overFlowData).select("rawPrediction", "probability").collect() - - // probabilities are correct when margins have to be adjusted - val raw1 = results(0).getAs[Vector](0) - val prob1 = results(0).getAs[Vector](1) - assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0)) - assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps) - // probabilities are correct when margins don't have to be adjusted - val raw2 = results(1).getAs[Vector](0) - val prob2 = results(1).getAs[Vector](1) - assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0)) - assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps) + testTransformerByGlobalCheckFunc[(Double, Vector)](overFlowData.toDF(), + model, "rawPrediction", "probability") { results: Seq[Row] => + // probabilities are correct when margins have to be adjusted + val raw1 = results(0).getAs[Vector](0) + val prob1 = results(0).getAs[Vector](1) + assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0)) + assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps) + + // probabilities are correct when margins don't have to be adjusted + val raw2 = results(1).getAs[Vector](0) + val prob2 = results(1).getAs[Vector](1) + assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0)) + assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps) + } } test("MultiClassSummarizer") { @@ -2567,10 +2513,13 @@ class LogisticRegressionSuite val model1 = lr.fit(smallBinaryDataset) val lr2 = new LogisticRegression().setInitialModel(model1).setMaxIter(5).setFamily("binomial") val model2 = lr2.fit(smallBinaryDataset) - val predictions1 = model1.transform(smallBinaryDataset).select("prediction").collect() - val predictions2 = model2.transform(smallBinaryDataset).select("prediction").collect() - predictions1.zip(predictions2).foreach { case (Row(p1: Double), Row(p2: Double)) => - assert(p1 === p2) + val binaryExpected = model1.transform(smallBinaryDataset).select("prediction").collect() + .map(_.getDouble(0)) + for (model <- Seq(model1, model2)) { + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), model, + "prediction") { rows: Seq[Row] => + rows.map(_.getDouble(0)).toArray === binaryExpected + } } assert(model2.summary.totalIterations === 1) @@ -2579,10 +2528,13 @@ class LogisticRegressionSuite val lr4 = new LogisticRegression() .setInitialModel(model3).setMaxIter(5).setFamily("multinomial") val model4 = lr4.fit(smallMultinomialDataset) - val predictions3 = model3.transform(smallMultinomialDataset).select("prediction").collect() - val predictions4 = model4.transform(smallMultinomialDataset).select("prediction").collect() - predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) => - assert(p1 === p2) + val multinomialExpected = model3.transform(smallMultinomialDataset).select("prediction") + .collect().map(_.getDouble(0)) + for (model <- Seq(model3, model4)) { + testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, + "prediction") { rows: Seq[Row] => + rows.map(_.getDouble(0)).toArray === multinomialExpected + } } assert(model4.summary.totalIterations === 1) } @@ -2638,8 +2590,8 @@ class LogisticRegressionSuite LabeledPoint(4.0, Vectors.dense(2.0))).toDF() val mlr = new LogisticRegression().setFamily("multinomial") val model 
= mlr.fit(constantData) - val results = model.transform(constantData) - results.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantData, model, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity))) assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0))) @@ -2653,8 +2605,8 @@ class LogisticRegressionSuite LabeledPoint(0.0, Vectors.dense(1.0)), LabeledPoint(0.0, Vectors.dense(2.0))).toDF() val modelZeroLabel = mlr.setFitIntercept(false).fit(constantZeroData) - val resultsZero = modelZeroLabel.transform(constantZeroData) - resultsZero.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantZeroData, modelZeroLabel, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(prob === Vectors.dense(Array(1.0))) assert(pred === 0.0) @@ -2666,8 +2618,8 @@ class LogisticRegressionSuite val constantDataWithMetadata = constantData .select(constantData("label").as("label", labelMeta), constantData("features")) val modelWithMetadata = mlr.setFitIntercept(true).fit(constantDataWithMetadata) - val resultsWithMetadata = modelWithMetadata.transform(constantDataWithMetadata) - resultsWithMetadata.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantDataWithMetadata, modelWithMetadata, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity, 0.0))) assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index d3141ec70856..6b5fe6e49ffe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -17,22 +17,17 @@ package org.apache.spark.ml.classification -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions._ -class MultilayerPerceptronClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -75,14 +70,24 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(dataset) - val result = model.transform(dataset) 
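For context on the perceptron suites being migrated here: the layers parameter lists layer sizes from input to output, so the first entry must equal the feature dimension and the last the number of classes (Array(2, 5, 2) in the single-instance test below means 2 inputs, one hidden layer of 5 units, 2 classes). A hedged construction sketch; the training DataFrame is illustrative and not defined here:

import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

val mlp = new MultilayerPerceptronClassifier()
  .setLayers(Array(2, 5, 2)) // 2 features -> 5 hidden units -> 2 classes
  .setBlockSize(1)
  .setSeed(123L)
  .setMaxIter(100)
  .setSolver("l-bfgs")
// val model = mlp.fit(trainingDF) // trainingDF would need "features" and "label" columns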
MLTestingUtils.checkCopyAndUids(trainer, model) - val predictionAndLabels = result.select("prediction", "label").collect() - predictionAndLabels.foreach { case Row(p: Double, l: Double) => - assert(p == l) + testTransformer[(Vector, Double)](dataset.toDF(), model, "prediction", "label") { + case Row(p: Double, l: Double) => assert(p == l) } } + test("prediction on single instance") { + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(123L) + .setMaxIter(100) + .setSolver("l-bfgs") + val model = trainer.fit(dataset) + testPredictionModelSinglePrediction(model, dataset) + } + test("Predicted class probabilities: calibration on toy dataset") { val layers = Array[Int](4, 5, 2) @@ -99,13 +104,12 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(strongDataset) - val result = model.transform(strongDataset) - result.select("probability", "expectedProbability").collect().foreach { - case Row(p: Vector, e: Vector) => - assert(p ~== e absTol 1e-3) + testTransformer[(Vector, Double, Vector)](strongDataset.toDF(), model, + "probability", "expectedProbability") { + case Row(p: Vector, e: Vector) => assert(p ~== e absTol 1e-3) } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, MultilayerPerceptronClassificationModel](model, strongDataset) + Vector, MultilayerPerceptronClassificationModel](this, model, strongDataset) } test("test model probability") { @@ -118,11 +122,10 @@ class MultilayerPerceptronClassifierSuite .setSolver("l-bfgs") val model = trainer.fit(dataset) model.setProbabilityCol("probability") - val result = model.transform(dataset) - val features2prob = udf { features: Vector => model.mlpModel.predict(features) } - result.select(features2prob(col("features")), col("probability")).collect().foreach { - case Row(p1: Vector, p2: Vector) => - assert(p1 ~== p2 absTol 1e-3) + testTransformer[(Vector, Double)](dataset.toDF(), model, "features", "probability") { + case Row(features: Vector, prob: Vector) => + val prob2 = model.mlpModel.predict(features) + assert(prob ~== prob2 absTol 1e-3) } } @@ -175,9 +178,6 @@ class MultilayerPerceptronClassifierSuite val model = trainer.fit(dataFrame) val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label").rdd.map { - case Row(p: Double, l: Double) => (p, l) - } // train multinomial logistic regression val lr = new LogisticRegressionWithLBFGS() .setIntercept(true) @@ -189,8 +189,12 @@ class MultilayerPerceptronClassifierSuite lrModel.predict(data.rdd.map(p => OldVectors.fromML(p.features))).zip(data.rdd.map(_.label)) // MLP's predictions should not differ a lot from LR's. 
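Both sides of that comparison go through MulticlassMetrics, so a minimal standalone sketch of the helper may help when reading the block below. The (prediction, label) pairs are illustrative and `sc` is an existing SparkContext:

import org.apache.spark.mllib.evaluation.MulticlassMetrics

val predictionAndLabels = sc.parallelize(Seq((0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (2.0, 2.0)))
val metrics = new MulticlassMetrics(predictionAndLabels)

println(metrics.confusionMatrix) // actual classes in rows, predicted classes in columns, labels ascending
println(metrics.accuracy)        // 0.75 for the pairs above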
val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) - val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) - assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) + testTransformerByGlobalCheckFunc[(Double, Vector)](dataFrame, model, "prediction", "label") { + rows: Seq[Row] => + val mlpPredictionAndLabels = rows.map(x => (x.getDouble(0), x.getDouble(1))) + val mlpMetrics = new MulticlassMetrics(sc.makeRDD(mlpPredictionAndLabels)) + assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) + } } test("read/write: MultilayerPerceptronClassifier") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 0d3adf993383..5f9ab98a2c3c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -28,12 +28,11 @@ import org.apache.spark.ml.classification.NaiveBayesSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row} -class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -56,13 +55,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa bernoulliDataset = generateNaiveBayesInput(pi, theta, 100, seed, "bernoulli").toDF() } - def validatePrediction(predictionAndLabels: DataFrame): Unit = { - val numOfErrorPredictions = predictionAndLabels.collect().count { + def validatePrediction(predictionAndLabels: Seq[Row]): Unit = { + val numOfErrorPredictions = predictionAndLabels.filter { case Row(prediction: Double, label: Double) => prediction != label - } + }.length // At least 80% of the predictions should be on. 
- assert(numOfErrorPredictions < predictionAndLabels.count() / 5) + assert(numOfErrorPredictions < predictionAndLabels.length / 5) } def validateModelFit( @@ -92,10 +91,10 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } def validateProbabilities( - featureAndProbabilities: DataFrame, + featureAndProbabilities: Seq[Row], model: NaiveBayesModel, modelType: String): Unit = { - featureAndProbabilities.collect().foreach { + featureAndProbabilities.foreach { case Row(features: Vector, probability: Vector) => assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) val expected = modelType match { @@ -154,15 +153,40 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") - validatePrediction(predictionAndLabels) + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "prediction", "label") { predictionAndLabels: Seq[Row] => + validatePrediction(predictionAndLabels) + } - val featureAndProbabilities = model.transform(validationDataset) - .select("features", "probability") - validateProbabilities(featureAndProbabilities, model, "multinomial") + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "features", "probability") { featureAndProbabilities: Seq[Row] => + validateProbabilities(featureAndProbabilities, model, "multinomial") + } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, NaiveBayesModel](model, testDataset) + Vector, NaiveBayesModel](this, model, testDataset) + } + + test("prediction on single instance") { + val nPoints = 1000 + val piArray = Array(0.5, 0.1, 0.4).map(math.log) + val thetaArray = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) + + val trainDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF() + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") + val model = nb.fit(trainDataset) + + val validationDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() + + testPredictionModelSinglePrediction(model, validationDataset) } test("Naive Bayes with weighted samples") { @@ -210,15 +234,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 20, "bernoulli").toDF() - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") - validatePrediction(predictionAndLabels) + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "prediction", "label") { predictionAndLabels: Seq[Row] => + validatePrediction(predictionAndLabels) + } - val featureAndProbabilities = model.transform(validationDataset) - .select("features", "probability") - validateProbabilities(featureAndProbabilities, model, "bernoulli") + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "features", "probability") { featureAndProbabilities: Seq[Row] => + validateProbabilities(featureAndProbabilities, model, "bernoulli") + } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, NaiveBayesModel](model, 
testDataset) + Vector, NaiveBayesModel](this, model, testDataset) } test("detect negative values") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 25bad59b9c9c..2c3417c7e402 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -17,26 +17,24 @@ package org.apache.spark.ml.classification -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.StringIndexer -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.Metadata -class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class OneVsRestSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -74,21 +72,18 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau .setClassifier(new LogisticRegression) assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") + assert(ova.getRawPredictionCol === "rawPrediction") val ovaModel = ova.fit(dataset) MLTestingUtils.checkCopyAndUids(ova, ovaModel) - assert(ovaModel.models.length === numClasses) + assert(ovaModel.numClasses === numClasses) val transformedDataset = ovaModel.transform(dataset) // check for label metadata in prediction col val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) - val ovaResults = transformedDataset.select("prediction", "label").rdd.map { - row => (row.getDouble(0), row.getDouble(1)) - } - val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) lr.optimizer.setRegParam(0.1).setNumIterations(100) @@ -97,8 +92,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau // determine the #confusion matrix in each class. // bound how much error we allow compared to multinomial logistic regression. 
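// Sketch (not part of this patch): with this change OneVsRest carries a rawPredictionCol in
// addition to predictionCol, as the assertions in this suite show. A minimal usage sketch,
// assuming an active SparkSession and a DataFrame `train` with "label" and "features" columns
// (both hypothetical):
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(10))
// val ovrModel = ovr.fit(train)
// ovrModel.transform(train).select("prediction", "rawPrediction").show()
// Setting rawPredictionCol to "" on the fitted model drops the column again, as exercised below.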
val expectedMetrics = new MulticlassMetrics(results) - val ovaMetrics = new MulticlassMetrics(ovaResults) - assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400) + + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), ovaModel, + "prediction", "label") { rows => + val ovaResults = rows.map { row => (row.getDouble(0), row.getDouble(1)) } + val ovaMetrics = new MulticlassMetrics(sc.makeRDD(ovaResults)) + assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400) + } } test("one-vs-rest: tuning parallelism does not change output") { @@ -180,6 +180,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea")) ovaModel.setFeaturesCol("fea") ovaModel.setPredictionCol("pred") + ovaModel.setRawPredictionCol("") val transformedDataset = ovaModel.transform(dataset2) val outputFields = transformedDataset.schema.fieldNames.toSet assert(outputFields === Set("y", "fea", "pred")) @@ -191,7 +192,8 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val ovr = new OneVsRest() .setClassifier(logReg) val output = ovr.fit(dataset).transform(dataset) - assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + assert(output.schema.fieldNames.toSet + === Set("label", "features", "prediction", "rawPrediction")) } test("SPARK-21306: OneVsRest should support setWeightCol") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index d649ceac949c..1c8c9829f18d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.MLTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} @@ -122,13 +123,15 @@ object ProbabilisticClassifierSuite { def testPredictMethods[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]]( - model: M, testData: Dataset[_]): Unit = { + mlTest: MLTest, model: M, testData: Dataset[_]): Unit = { val allColModel = model.copy(ParamMap.empty) .setRawPredictionCol("rawPredictionAll") .setProbabilityCol("probabilityAll") .setPredictionCol("predictionAll") - val allColResult = allColModel.transform(testData) + + val allColResult = allColModel.transform(testData.select(allColModel.getFeaturesCol)) + .select(allColModel.getFeaturesCol, "rawPredictionAll", "probabilityAll", "predictionAll") for (rawPredictionCol <- Seq("", "rawPredictionSingle")) { for (probabilityCol <- Seq("", "probabilitySingle")) { @@ -138,22 +141,14 @@ object ProbabilisticClassifierSuite { .setProbabilityCol(probabilityCol) .setPredictionCol(predictionCol) - val result = newModel.transform(allColResult) - - import org.apache.spark.sql.functions._ - - val resultRawPredictionCol = - if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol) - val resultProbabilityCol = - if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol) - val resultPredictionCol = - if (predictionCol.isEmpty) 
col("predictionAll") else col(predictionCol) + import allColResult.sparkSession.implicits._ - result.select( - resultRawPredictionCol, col("rawPredictionAll"), - resultProbabilityCol, col("probabilityAll"), - resultPredictionCol, col("predictionAll") - ).collect().foreach { + mlTest.testTransformer[(Vector, Vector, Vector, Double)](allColResult, newModel, + if (rawPredictionCol.isEmpty) "rawPredictionAll" else rawPredictionCol, + "rawPredictionAll", + if (probabilityCol.isEmpty) "probabilityAll" else probabilityCol, "probabilityAll", + if (predictionCol.isEmpty) "predictionAll" else predictionCol, "predictionAll" + ) { case Row( rawPredictionSingle: Vector, rawPredictionAll: Vector, probabilitySingle: Vector, probabilityAll: Vector, diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 2cca2e6c0469..3062aa9f3d27 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,13 +21,12 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.ClassificationLeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -35,8 +34,7 @@ import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. 
*/ -class RandomForestClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { import RandomForestClassifierSuite.compareAPIs import testImplicits._ @@ -73,7 +71,8 @@ class RandomForestClassifierSuite test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2) + Array(new DecisionTreeClassificationModel("dtc", + new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2) ParamsSuite.checkParams(model) } @@ -143,11 +142,8 @@ class RandomForestClassifierSuite MLTestingUtils.checkCopyAndUids(rf, model) - val predictions = model.transform(df) - .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) - .collect() - - predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction", + "probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) => assert(pred === rawPred.argmax, s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") val sum = rawPred.toArray.sum @@ -155,8 +151,25 @@ class RandomForestClassifierSuite "probability prediction mismatch") assert(probPred.toArray.sum ~== 1.0 relTol 1E-5) } + ProbabilisticClassifierSuite.testPredictMethods[ - Vector, RandomForestClassificationModel](model, df) + Vector, RandomForestClassificationModel](this, model, df) + } + + test("prediction on single instance") { + val rdd = orderedLabeledPoints5_20 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + testPredictionModelSinglePrediction(model, df) } test("Fitting without numClasses in metadata") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index fa7471fa2d65..f3ff2afcad2c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.ml.clustering -import org.apache.spark.SparkFunSuite +import scala.language.existentials + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.clustering.DistanceMeasure import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{DataFrame, Dataset} class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -140,6 +145,62 @@ class BisectingKMeansSuite testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, BisectingKMeansSuite.allParamSettings, checkModelData) } + + test("BisectingKMeans with cosine distance is not supported for 0-length vectors") { + val model = new 
BisectingKMeans().setK(2).setDistanceMeasure(DistanceMeasure.COSINE).setSeed(1) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(0.0, 0.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5) + )).map(v => TestRow(v))) + val e = intercept[SparkException](model.fit(df)) + assert(e.getCause.isInstanceOf[AssertionError]) + assert(e.getCause.getMessage.contains("Cosine distance is not defined")) + } + + test("BisectingKMeans with cosine distance") { + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(1.0, 1.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5), + Vectors.dense(10.0, 4.4), + Vectors.dense(-1.0, 1.0), + Vectors.dense(-100.0, 90.0) + )).map(v => TestRow(v))) + val model = new BisectingKMeans() + .setK(3) + .setDistanceMeasure(DistanceMeasure.COSINE) + .setSeed(1) + .fit(df) + val predictionDf = model.transform(df) + assert(predictionDf.select("prediction").distinct().count() == 3) + val predictionsMap = predictionDf.collect().map(row => + row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap + assert(predictionsMap(Vectors.dense(1.0, 1.0)) == + predictionsMap(Vectors.dense(10.0, 10.0))) + assert(predictionsMap(Vectors.dense(1.0, 0.5)) == + predictionsMap(Vectors.dense(10.0, 4.4))) + assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == + predictionsMap(Vectors.dense(-100.0, 90.0))) + + model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) + } + + test("BisectingKMeans with Array input") { + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new BisectingKMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) + } } object BisectingKMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 08b800b7e418..d0d461a42711 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.ml.param.ParamMap @@ -24,8 +26,7 @@ import org.apache.spark.ml.stat.distribution.MultivariateGaussian import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} - +import org.apache.spark.sql.{DataFrame, Dataset, Row} class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -256,6 +257,22 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, triangularValues) assert(symmetricMatrix === expectedMatrix) } + + test("GaussianMixture with Array input") { + def 
trainAndComputlogLikelihood(dataset: Dataset[_]): Double = { + val model = new GaussianMixture().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.summary.logLikelihood + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueLikelihood = trainAndComputlogLikelihood(newDataset) + val doubleLikelihood = trainAndComputlogLikelihood(newDatasetD) + val floatLikelihood = trainAndComputlogLikelihood(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueLikelihood ~== doubleLikelihood absTol 1e-6) + assert(trueLikelihood ~== floatLikelihood absTol 1e-6) + } } object GaussianMixtureSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 119fe1dead9a..680a7c203408 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -17,19 +17,25 @@ package org.apache.spark.ml.clustering +import scala.language.existentials import scala.util.Random -import org.apache.spark.SparkFunSuite +import org.dmg.pmml.{ClusteringModel, PMML} + +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest + with PMMLReadWriteTest { final val k = 5 @transient var dataset: Dataset[_] = _ @@ -50,6 +56,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 2) assert(kmeans.getTol === 1e-4) + assert(kmeans.getDistanceMeasure === DistanceMeasure.EUCLIDEAN) val model = kmeans.setMaxIter(1).fit(dataset) MLTestingUtils.checkCopyAndUids(kmeans, model) @@ -68,6 +75,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR .setInitSteps(3) .setSeed(123) .setTol(1e-3) + .setDistanceMeasure(DistanceMeasure.COSINE) assert(kmeans.getK === 9) assert(kmeans.getFeaturesCol === "test_feature") @@ -77,6 +85,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitSteps === 3) assert(kmeans.getSeed === 123) assert(kmeans.getTol === 1e-3) + assert(kmeans.getDistanceMeasure === DistanceMeasure.COSINE) } test("parameters validation") { @@ -89,6 +98,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR intercept[IllegalArgumentException] { new KMeans().setInitSteps(0) } + intercept[IllegalArgumentException] { + new KMeans().setDistanceMeasure("no_such_a_measure") + } } test("fit, transform and summary") { @@ -144,6 +156,67 @@ class KMeansSuite extends 
SparkFunSuite with MLlibTestSparkContext with DefaultR assert(model.getPredictionCol == predictionColName) } + test("KMeans using cosine distance") { + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(1.0, 1.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5), + Vectors.dense(10.0, 4.4), + Vectors.dense(-1.0, 1.0), + Vectors.dense(-100.0, 90.0) + )).map(v => TestRow(v))) + + val model = new KMeans() + .setK(3) + .setSeed(1) + .setInitMode(MLlibKMeans.RANDOM) + .setTol(1e-6) + .setDistanceMeasure(DistanceMeasure.COSINE) + .fit(df) + + val predictionDf = model.transform(df) + assert(predictionDf.select("prediction").distinct().count() == 3) + val predictionsMap = predictionDf.collect().map(row => + row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap + assert(predictionsMap(Vectors.dense(1.0, 1.0)) == + predictionsMap(Vectors.dense(10.0, 10.0))) + assert(predictionsMap(Vectors.dense(1.0, 0.5)) == + predictionsMap(Vectors.dense(10.0, 4.4))) + assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == + predictionsMap(Vectors.dense(-100.0, 90.0))) + + model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) + } + + test("KMeans with cosine distance is not supported for 0-length vectors") { + val model = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(0.0, 0.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5) + )).map(v => TestRow(v))) + val e = intercept[SparkException](model.fit(df)) + assert(e.getCause.isInstanceOf[AssertionError]) + assert(e.getCause.getMessage.contains("Cosine distance is not defined")) + } + + test("KMeans with Array input") { + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) + } + + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) @@ -152,6 +225,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, KMeansSuite.allParamSettings, checkModelData) } + + test("pmml export") { + val clusterCenters = Array( + MLlibVectors.dense(1.0, 2.0, 6.0), + MLlibVectors.dense(1.0, 3.0, 0.0), + MLlibVectors.dense(1.0, 4.0, 6.0)) + val oldKmeansModel = new MLlibKMeansModel(clusterCenters) + val kmeansModel = new KMeansModel("", oldKmeansModel) + def checkModel(pmml: PMML): Unit = { + // Check the header description is what we expect + assert(pmml.getHeader.getDescription === "k-means clustering") + // check that the number of fields matches the single vector size + assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) + // This verifies that there is a model attached to the pmml object and the model is a clustering + // one. It also verifies that the pmml model has the same number of clusters as the spark + // model.
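// Sketch (not part of this patch): the distanceMeasure param exercised in the tests above can be
// set directly on the estimator. Assumes an active SparkSession and a DataFrame `df` with a
// "features" vector column (hypothetical):
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.clustering.DistanceMeasure

val cosineKMeans = new KMeans()
  .setK(3)
  .setSeed(1)
  .setDistanceMeasure(DistanceMeasure.COSINE)
// val cosineModel = cosineKMeans.fit(df)
// cosineModel.transform(df).select("features", "prediction").show()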
+ val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] + assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) + } + testPMMLWrite(sc, kmeansModel, checkModel) + } } object KMeansSuite { @@ -182,6 +276,7 @@ object KMeansSuite { "predictionCol" -> "myPrediction", "k" -> 3, "maxIter" -> 2, - "tol" -> 0.01 + "tol" -> 0.01, + "distanceMeasure" -> DistanceMeasure.EUCLIDEAN ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index e73bbc18d76b..8d728f063dd8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkFunSuite @@ -26,7 +28,6 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ - object LDASuite { def generateLDAData( spark: SparkSession, @@ -323,4 +324,21 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getOptimizer === optimizer) } } + + test("LDA with Array input") { + def trainAndLogLikelihoodAndPerplexity(dataset: Dataset[_]): (Double, Double) = { + val model = new LDA().setK(k).setOptimizer("online").setMaxIter(1).setSeed(1).fit(dataset) + (model.logLikelihood(dataset), model.logPerplexity(dataset)) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val (ll, lp) = trainAndLogLikelihoodAndPerplexity(newDataset) + val (llD, lpD) = trainAndLogLikelihoodAndPerplexity(newDatasetD) + val (llF, lpF) = trainAndLogLikelihoodAndPerplexity(newDatasetF) + // TODO: need to compare the results once we fix the seed issue for LDA (SPARK-22210) + assert(llD <= 0.0 && llD != Double.NegativeInfinity) + assert(llF <= 0.0 && llF != Double.NegativeInfinity) + assert(lpD >= 0.0 && lpD != Double.NegativeInfinity) + assert(lpF >= 0.0 && lpF != Double.NegativeInfinity) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala new file mode 100644 index 000000000000..65328df17baf --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.clustering + +import scala.collection.mutable + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + + +class PowerIterationClusteringSuite extends SparkFunSuite + with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var data: Dataset[_] = _ + final val r1 = 1.0 + final val n1 = 10 + final val r2 = 4.0 + final val n2 = 40 + + override def beforeAll(): Unit = { + super.beforeAll() + + data = PowerIterationClusteringSuite.generatePICData(spark, r1, r2, n1, n2) + } + + test("default parameters") { + val pic = new PowerIterationClustering() + + assert(pic.getK === 2) + assert(pic.getMaxIter === 20) + assert(pic.getInitMode === "random") + assert(pic.getPredictionCol === "prediction") + assert(pic.getIdCol === "id") + assert(pic.getNeighborsCol === "neighbors") + assert(pic.getSimilaritiesCol === "similarities") + } + + test("parameter validation") { + intercept[IllegalArgumentException] { + new PowerIterationClustering().setK(1) + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setInitMode("no_such_a_mode") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setIdCol("") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setNeighborsCol("") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setSimilaritiesCol("") + } + } + + test("power iteration clustering") { + val n = n1 + n2 + + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + val result = model.transform(data) + + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + result.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions(cluster) += id + } + assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + + val result2 = new PowerIterationClustering() + .setK(2) + .setMaxIter(10) + .setInitMode("degree") + .transform(data) + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + result2.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions2(cluster) += id + } + assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + } + + test("supported input types") { + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(1) + + def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = { + val typedData = data.select( + col("id").cast(idType).alias("id"), + col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"), + col("similarities").cast(ArrayType(similarityType, containsNull = false)) + .alias("similarities") + ) + model.transform(typedData).collect() + } + + for (idType <- Seq(IntegerType, LongType)) { + runTest(idType, LongType, DoubleType) + } + for (neighborType <- Seq(IntegerType, LongType)) { + runTest(LongType, neighborType, DoubleType) + } + for (similarityType <- Seq(FloatType, DoubleType)) { + runTest(LongType, LongType, similarityType) + } + } + + test("invalid input: wrong types") { + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(1) + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id").cast(DoubleType).alias("id"), + col("neighbors"), 
+ col("similarities") + ) + model.transform(typedData) + } + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id"), + col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"), + col("similarities") + ) + model.transform(typedData) + } + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id"), + col("neighbors"), + col("neighbors").alias("similarities") + ) + model.transform(typedData) + } + } + + test("invalid input: negative similarity") { + val model = new PowerIterationClustering() + .setMaxIter(1) + val badData = spark.createDataFrame(Seq( + (0, Array(1), Array(-1.0)), + (1, Array(0), Array(-1.0)) + )).toDF("id", "neighbors", "similarities") + val msg = intercept[SparkException] { + model.transform(badData) + }.getCause.getMessage + assert(msg.contains("Similarity must be nonnegative")) + } + + test("invalid input: mismatched lengths for neighbor and similarity arrays") { + val model = new PowerIterationClustering() + .setMaxIter(1) + val badData = spark.createDataFrame(Seq( + (0, Array(1), Array(0.5)), + (1, Array(0, 2), Array(0.5)), + (2, Array(1), Array(0.5)) + )).toDF("id", "neighbors", "similarities") + val msg = intercept[SparkException] { + model.transform(badData) + }.getCause.getMessage + assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " + + "the neighbor similarity list.")) + assert(msg.contains(s"Row for ID ${model.getIdCol}=1")) + } + + test("read/write") { + val t = new PowerIterationClustering() + .setK(4) + .setMaxIter(100) + .setInitMode("degree") + .setIdCol("test_id") + .setNeighborsCol("myNeighborsCol") + .setSimilaritiesCol("mySimilaritiesCol") + .setPredictionCol("test_prediction") + testDefaultReadWrite(t) + } +} + +object PowerIterationClusteringSuite { + + /** Generates a circle of points. */ + private def genCircle(r: Double, n: Int): Array[(Double, Double)] = { + Array.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (r * math.cos(theta), r * math.sin(theta)) + } + } + + /** Computes Gaussian similarity. */ + private def sim(x: (Double, Double), y: (Double, Double)): Double = { + val dist2 = (x._1 - y._1) * (x._1 - y._1) + (x._2 - y._2) * (x._2 - y._2) + math.exp(-dist2 / 2.0) + } + + def generatePICData( + spark: SparkSession, + r1: Double, + r2: Double, + n1: Int, + n2: Int): DataFrame = { + // Generate two circles following the example in the PIC paper. 
+ val n = n1 + n2 + val points = genCircle(r1, n1) ++ genCircle(r2, n2) + + val rows = for (i <- 1 until n) yield { + val neighbors = for (j <- 0 until i) yield { + j.toLong + } + val similarities = for (j <- 0 until i) yield { + sim(points(i), points(j)) + } + (i.toLong, neighbors.toArray, similarities.toArray) + } + + spark.createDataFrame(rows).toDF("id", "neighbors", "similarities") + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 677ce49a903a..2c175ff68e0b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ @@ -66,16 +68,57 @@ class ClusteringEvaluatorSuite assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) } - test("number of clusters must be greater than one") { - val singleClusterDataset = irisDataset.where($"label" === 0.0) + /* + Use the following python code to load the data and evaluate it using scikit-learn package. + + from sklearn import datasets + from sklearn.metrics import silhouette_score + iris = datasets.load_iris() + round(silhouette_score(iris.data, iris.target, metric='cosine'), 10) + + 0.7222369298 + */ + test("cosine Silhouette") { val evaluator = new ClusteringEvaluator() .setFeaturesCol("features") .setPredictionCol("label") + .setDistanceMeasure("cosine") + + assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5) + } - val e = intercept[AssertionError]{ - evaluator.evaluate(singleClusterDataset) + test("number of clusters must be greater than one") { + val singleClusterDataset = irisDataset.where($"label" === 0.0) + Seq("squaredEuclidean", "cosine").foreach { distanceMeasure => + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + .setDistanceMeasure(distanceMeasure) + + val e = intercept[AssertionError] { + evaluator.evaluate(singleClusterDataset) + } + assert(e.getMessage.contains("Number of clusters must be greater than one")) } - assert(e.getMessage.contains("Number of clusters must be greater than one")) } + test("SPARK-23568: we should use metadata to determine features number") { + val attributesNum = irisDataset.select("features").rdd.first().getAs[Vector](0).size + val attrGroup = new AttributeGroup("features", attributesNum) + val df = irisDataset.select($"features".as("features", attrGroup.toMetadata()), $"label") + require(AttributeGroup.fromStructField(df.schema("features")) + .numAttributes.isDefined, "numAttributes metadata should be defined") + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + + // with the proper metadata we compute correctly the result + assert(evaluator.evaluate(df) ~== 0.6564679231 relTol 1e-5) + + val wrongAttrGroup = new AttributeGroup("features", attributesNum + 1) + val dfWrong = irisDataset.select($"features".as("features", wrongAttrGroup.toMetadata()), + $"label") + // with wrong metadata the evaluator throws an Exception + 
intercept[SparkException](evaluator.evaluate(dfWrong)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 4455d3521087..05d4a6ee2dab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BinarizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -47,7 +45,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau .setInputCol("feature") .setOutputCol("binarized_feature") - binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, binarizer, "binarized_feature", "expected") { case Row(x: Double, y: Double) => assert(x === y, "The feature value is not correct after binarization.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index 7175c721bff3..9b823259b1de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -20,16 +20,15 @@ package org.apache.spark.ml.feature import breeze.numerics.{cos, sin} import breeze.numerics.constants.Pi -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{Dataset, Row} -class BucketedRandomProjectionLSHSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ @transient var dataset: Dataset[_] = _ @@ -49,6 +48,14 @@ class BucketedRandomProjectionLSHSuite ParamsSuite.checkParams(model) } + test("setters") { + val model = new BucketedRandomProjectionLSHModel("brp", Array(Vectors.dense(0.0, 1.0))) + .setInputCol("testkeys") + .setOutputCol("testvalues") + assert(model.getInputCol === "testkeys") + assert(model.getOutputCol === "testvalues") + } + test("BucketedRandomProjectionLSH: default params") { val brp = new BucketedRandomProjectionLSH assert(brp.getNumHashTables === 1.0) @@ -98,6 +105,21 @@ class BucketedRandomProjectionLSHSuite MLTestingUtils.checkCopyAndUids(brp, brpModel) } + test("BucketedRandomProjectionLSH: streaming transform") { + val brp = new BucketedRandomProjectionLSH() + .setNumHashTables(2) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(1.0) + 
.setSeed(12345) + val brpModel = brp.fit(dataset) + + testTransformer[Tuple1[Vector]](dataset.toDF(), brpModel, "values") { + case Row(values: Seq[_]) => + assert(values.length === brp.getNumHashTables) + } + } + test("BucketedRandomProjectionLSH: test of LSH property") { // Project from 2 dimensional Euclidean Space to 1 dimensions val brp = new BucketedRandomProjectionLSH() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index d9c97ae8067d..9ea15e191853 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -23,14 +23,13 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BucketizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -50,7 +49,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -84,7 +83,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -103,7 +102,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setSplits(splits) bucketizer.setHandleInvalid("keep") - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. 
Expected $y but found $x") @@ -172,7 +171,10 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setSplits(Array(0.1, 0.8, 0.9)) - testDefaultReadWrite(t) + + val bucketizer = testDefaultReadWrite(t) + val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2") + bucketizer.transform(data) } test("Bucket numeric features") { @@ -216,8 +218,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer1.isBucketizeMultipleColumns()) - bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2") BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame), Seq("result1", "result2"), @@ -233,8 +233,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result")) .setSplitsArray(Array(splits(0))) - assert(bucketizer2.isBucketizeMultipleColumns()) - withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer2.transform(badDF1).collect() @@ -268,8 +266,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer.isBucketizeMultipleColumns()) - BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), Seq("result1", "result2"), Seq("expected1", "expected2")) @@ -295,8 +291,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer.isBucketizeMultipleColumns()) - bucketizer.setHandleInvalid("keep") BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), Seq("result1", "result2"), @@ -335,8 +329,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCols(Array("myInputCol")) .setOutputCols(Array("myOutputCol")) .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) - assert(t.isBucketizeMultipleColumns()) - testDefaultReadWrite(t) + + val bucketizer = testDefaultReadWrite(t) + val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2") + bucketizer.transform(data) + assert(t.hasDefault(t.outputCol)) + assert(bucketizer.hasDefault(bucketizer.outputCol)) } test("Bucketizer in a pipeline") { @@ -348,8 +346,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))) - assert(bucket.isBucketizeMultipleColumns()) - val pl = new Pipeline() .setStages(Array(bucket)) .fit(df) @@ -401,15 +397,27 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } - test("Both inputCol and inputCols are set") { - val bucket = new Bucketizer() - .setInputCol("feature1") - .setOutputCol("result") - .setSplits(Array(-0.5, 0.0, 0.5)) - .setInputCols(Array("feature1", "feature2")) - - // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`. 
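// Sketch (not part of this patch): the multi-column Bucketizer API these tests exercise.
// Assumes an active SparkSession with implicits imported; `df` is hypothetical:
import org.apache.spark.ml.feature.Bucketizer

// val df = Seq((-0.4, 0.1), (0.2, -0.3)).toDF("feature1", "feature2")
val multiColBucketizer = new Bucketizer()
  .setInputCols(Array("feature1", "feature2"))
  .setOutputCols(Array("result1", "result2"))
  .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))
// multiColBucketizer.transform(df).show()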
- assert(bucket.isBucketizeMultipleColumns() == false) + test("assert exception is thrown if both multi-column and single-column params are set") { + val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("inputCols", Array("feature1", "feature2"))) + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)), + ("outputCols", Array("result1", "result2"))) + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)), + ("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))) + + // this should fail because at least one of inputCol and inputCols must be set + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "feature1"), + ("splits", Array(-0.5, 0.0, 0.5))) + + // the following should fail because not all the params are set + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1")) + ParamsSuite.testExclusiveParams(new Bucketizer, df, + ("inputCols", Array("feature1", "feature2")), + ("outputCols", Array("result1", "result2"))) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index c83909c4498f..c843df9f33e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ @transient var dataset: Dataset[_] = _ @@ -119,32 +118,32 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext test("Test Chi-Square selector: numTopFeatures") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) - val model = ChiSqSelectorSuite.testSelector(selector, dataset) + val model = testSelector(selector, dataset) MLTestingUtils.checkCopyAndUids(selector, model) } test("Test Chi-Square selector: percentile") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.17) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fpr") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.02) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fdr") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fdr").setFdr(0.12) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fwe") { val selector = 
new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fwe").setFwe(0.12) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("read/write") { @@ -163,18 +162,19 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext assert(expected.selectedFeatures === actual.selectedFeatures) } } -} -object ChiSqSelectorSuite { - - private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = { - val selectorModel = selector.fit(dataset) - selectorModel.transform(dataset).select("filtered", "topFeature").collect() - .foreach { case Row(vec1: Vector, vec2: Vector) => + private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = { + val selectorModel = selector.fit(data) + testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel, + "filtered", "topFeature") { + case Row(vec1: Vector, vec2: Vector) => assert(vec1 ~== vec2 absTol 1e-1) - } + } selectorModel } +} + +object ChiSqSelectorSuite { /** * Mapping from all Params to valid settings which differ from the defaults. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index f213145f1ba0..61217669d927 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -16,16 +16,13 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class CountVectorizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -50,7 +47,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -72,7 +69,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext MLTestingUtils.checkCopyAndUids(cv, cvm) assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e")) - cvm.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvm, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -100,7 +97,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .fit(df) assert(cvModel2.vocabulary === Array("a", "b")) - cvModel2.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvModel2, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -113,7 +110,79 @@ class CountVectorizerSuite extends SparkFunSuite with 
MLlibTestSparkContext .fit(df) assert(cvModel3.vocabulary === Array("a", "b")) - cvModel3.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvModel3, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer maxDF") { + val df = Seq( + (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0), (2, 1.0)))), + (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0)))), + (3, split("a"), Vectors.sparse(3, Seq())) + ).toDF("id", "words", "expected") + + // maxDF: ignore terms with count more than 3 + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMaxDF(3) + .fit(df) + assert(cvModel.vocabulary === Array("b", "c", "d")) + + cvModel.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // maxDF: ignore terms with freq > 0.75 + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMaxDF(0.75) + .fit(df) + assert(cvModel2.vocabulary === Array("b", "c", "d")) + + cvModel2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer using both minDF and maxDF") { + // Ignore terms with count more than 3 AND less than 2 + val df = Seq( + (0, split("a b c d"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), + (1, split("a b c"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(2, Seq((0, 1.0)))), + (3, split("a"), Vectors.sparse(2, Seq())) + ).toDF("id", "words", "expected") + + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(2) + .setMaxDF(3) + .fit(df) + assert(cvModel.vocabulary === Array("b", "c")) + + cvModel.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // Ignore terms with frequency higher than 0.75 AND less than 0.5 + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(0.5) + .setMaxDF(0.75) + .fit(df) + assert(cvModel2.vocabulary === Array("b", "c")) + + cvModel2.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -147,7 +216,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setMinTF(3) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -166,7 +235,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setMinTF(0.3) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -186,7 +255,7 @@ class CountVectorizerSuite extends SparkFunSuite with 
MLlibTestSparkContext .setOutputCol("features") .setBinary(true) .fit(df) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -196,7 +265,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setBinary(true) - cv2.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv2, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 8dd3dd75e1be..6734336aac39 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -21,16 +21,14 @@ import scala.beans.BeanInfo import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.Row @BeanInfo case class DCTTestData(vec: Vector, wantedVec: Vector) -class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class DCTSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -72,11 +70,9 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setOutputCol("resultVec") .setInverse(inverse) - transformer.transform(dataset) - .select("resultVec", "wantedVec") - .collect() - .foreach { case Row(resultVec: Vector, wantedVec: Vector) => - assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) + testTransformer[(Vector, Vector)](dataset, transformer, "resultVec", "wantedVec") { + case Row(resultVec: Vector, wantedVec: Vector) => + assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala index a4cca27be781..3a8d0762e2ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala @@ -17,13 +17,31 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.sql.Row -class ElementwiseProductSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class ElementwiseProductSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ + + test("streaming transform") { + val scalingVec = Vectors.dense(0.1, 10.0) + val data = Seq( + (Vectors.dense(0.1, 1.0), Vectors.dense(0.01, 10.0)), + (Vectors.dense(0.0, -1.1), Vectors.dense(0.0, -11.0)) + ) + val df = spark.createDataFrame(data).toDF("features", "expected") + 
val ep = new ElementwiseProduct() + .setInputCol("features") + .setOutputCol("actual") + .setScalingVec(scalingVec) + testTransformer[(Vector, Vector)](df, ep, "actual", "expected") { + case Row(actual: Vector, expected: Vector) => + assert(actual ~== expected relTol 1e-14) + } + } test("read/write") { val ep = new ElementwiseProduct() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala index 3fc3cbb62d5b..d799ba6011fa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala @@ -17,26 +17,24 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils -class FeatureHasherSuite extends SparkFunSuite - with MLlibTestSparkContext - with DefaultReadWriteTest { +class FeatureHasherSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ - import HashingTFSuite.murmur3FeatureIdx + import FeatureHasherSuite.murmur3FeatureIdx - implicit private val vectorEncoder = ExpressionEncoder[Vector]() + implicit private val vectorEncoder: ExpressionEncoder[Vector] = ExpressionEncoder[Vector]() test("params") { ParamsSuite.checkParams(new FeatureHasher) @@ -51,31 +49,31 @@ class FeatureHasherSuite extends SparkFunSuite } test("feature hashing") { + val numFeatures = 100 + // Assume perfect hash on field names in computing expected results + def idx: Any => Int = murmur3FeatureIdx(numFeatures) + val df = Seq( - (2.0, true, "1", "foo"), - (3.0, false, "2", "bar") - ).toDF("real", "bool", "stringNum", "string") + (2.0, true, "1", "foo", + Vectors.sparse(numFeatures, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0), + (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0)))), + (3.0, false, "2", "bar", + Vectors.sparse(numFeatures, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0), + (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0)))) + ).toDF("real", "bool", "stringNum", "string", "expected") - val n = 100 val hasher = new FeatureHasher() .setInputCols("real", "bool", "stringNum", "string") .setOutputCol("features") - .setNumFeatures(n) + .setNumFeatures(numFeatures) val output = hasher.transform(df) val attrGroup = AttributeGroup.fromStructField(output.schema("features")) - assert(attrGroup.numAttributes === Some(n)) + assert(attrGroup.numAttributes === Some(numFeatures)) - val features = output.select("features").as[Vector].collect() - // Assume perfect hash on field names - def idx: Any => Int = murmur3FeatureIdx(n) - // check expected indices - val expected = Seq( - Vectors.sparse(n, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0), - (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0))), - Vectors.sparse(n, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0), - (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0))) - ) - assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 
1e-14 }) + testTransformer[(Double, Boolean, String, String, Vector)](df, hasher, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14 ) + } } test("setting explicit numerical columns to treat as categorical") { @@ -216,3 +214,11 @@ class FeatureHasherSuite extends SparkFunSuite testDefaultReadWrite(t) } } + +object FeatureHasherSuite { + + private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = { + Utils.nonNegativeMod(FeatureHasher.murmur3Hash(term), numFeatures) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index a46272fdce1f..c5183ecfef7d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row import org.apache.spark.util.Utils -class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class HashingTFSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import HashingTFSuite.murmur3FeatureIdx @@ -37,21 +36,28 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("hashingTF") { - val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words") - val n = 100 + val numFeatures = 100 + // Assume perfect hash when computing expected features. + def idx: Any => Int = murmur3FeatureIdx(numFeatures) + val data = Seq( + ("a a b b c d".split(" ").toSeq, + Vectors.sparse(numFeatures, + Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))) + ) + + val df = data.toDF("words", "expected") val hashingTF = new HashingTF() .setInputCol("words") .setOutputCol("features") - .setNumFeatures(n) + .setNumFeatures(numFeatures) val output = hashingTF.transform(df) val attrGroup = AttributeGroup.fromStructField(output.schema("features")) - require(attrGroup.numAttributes === Some(n)) - val features = output.select("features").first().getAs[Vector](0) - // Assume perfect hash on "a", "b", "c", and "d". 
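// For the input "a a b b c d" the raw term counts are a = 2, b = 2, c = 1, d = 1; under the
// perfect-hash assumption each term t gets its own slot idx(t) = murmur3FeatureIdx(numFeatures)(t)
// (presumably nonNegativeMod(murmur3Hash(t), numFeatures), mirroring the FeatureHasherSuite helper
// added above), which is exactly the sparse "expected" vector built in the rewritten test.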
- def idx: Any => Int = murmur3FeatureIdx(n) - val expected = Vectors.sparse(n, - Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) - assert(features ~== expected absTol 1e-14) + require(attrGroup.numAttributes === Some(numFeatures)) + + testTransformer[(Seq[String], Vector)](df, hashingTF, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } } test("applying binary term freqs") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 005edf73d29b..cdd62be43b54 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class IDFSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -57,7 +55,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(0.0, 1.0, 2.0, 3.0), Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) ) - val numOfData = data.size + val numOfData = data.length val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => math.log((numOfData + 1.0) / (x + 1.0)) }) @@ -72,7 +70,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead MLTestingUtils.checkCopyAndUids(idfEst, idfModel) - idfModel.transform(df).select("idfValue", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } @@ -85,7 +83,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(0.0, 1.0, 2.0, 3.0), Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) ) - val numOfData = data.size + val numOfData = data.length val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0 }) @@ -99,7 +97,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setMinDocFreq(1) .fit(df) - idfModel.transform(df).select("idfValue", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index c08b35b41926..75f63a623e6d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -16,13 +16,12 @@ */ package org.apache.spark.ml.feature -import 
org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.SparkException +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Imputer for Double with default missing Value NaN") { val df = spark.createDataFrame( Seq( @@ -76,6 +75,28 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default ImputerSuite.iterateStrategyTest(imputer, df) } + test("Imputer should work with Structured Streaming") { + val localSpark = spark + import localSpark.implicits._ + val df = Seq[(java.lang.Double, Double)]( + (4.0, 4.0), + (10.0, 10.0), + (10.0, 10.0), + (Double.NaN, 8.0), + (null, 8.0) + ).toDF("value", "expected_mean_value") + val imputer = new Imputer() + .setInputCols(Array("value")) + .setOutputCols(Array("out")) + .setStrategy("mean") + val model = imputer.fit(df) + testTransformer[(java.lang.Double, Double)](df, model, "expected_mean_value", "out") { + case Row(exp: java.lang.Double, out: Double) => + assert((exp.isNaN && out.isNaN) || (exp == out), + s"Imputed values differ. Expected: $exp, actual: $out") + } + } + test("Imputer throws exception when surrogate cannot be computed") { val df = spark.createDataFrame( Seq( (0, Double.NaN, 1.0, 1.0), @@ -164,8 +185,6 @@ object ImputerSuite { * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" */ def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = { - val inputCols = imputer.getInputCols - Seq("mean", "median").foreach { strategy => imputer.setStrategy(strategy) val model = imputer.fit(df) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 54f059e5f143..eea31fc7ae3f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class InteractionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -63,9 +63,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def test("numeric interaction") { val data = Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0)) - ).toDF("a", "b") + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) + ).toDF("a", "b", "expected") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -73,14 +73,15 @@ class InteractionSuite extends SparkFunSuite with 
MLlibTestSparkContext with Def NumericAttribute.defaultAttr.withName("bar"))) val df = data.select( col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), - col("b").as("b", groupAttr.toMetadata())) + col("b").as("b", groupAttr.toMetadata()), + col("expected")) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features === expected) + } + val res = trans.transform(df) - val expected = Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) - ).toDF("a", "b", "features") - assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -92,9 +93,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def test("nominal interaction") { val data = Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0)) - ).toDF("a", "b") + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) + ).toDF("a", "b", "expected") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -103,14 +104,15 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def val df = data.select( col("a").as( "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), - col("b").as("b", groupAttr.toMetadata())) + col("b").as("b", groupAttr.toMetadata()), + col("expected")) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features === expected) + } + val res = trans.transform(df) - val expected = Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) - ).toDF("a", "b", "features") - assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( "features", diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index 918da4f9388d..8dd0f0cb91e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -14,15 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row -class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MaxAbsScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -45,9 +44,10 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De .setOutputCol("scaled") val model = scaler.fit(df) - model.transform(df).select("expected", "scaled").collect() - .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1") + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { + case Row(expectedVec: Vector, actualVec: Vector) => + assert(expectedVec === actualVec, + s"MaxAbsScaler error: Expected $expectedVec but computed $actualVec") } MLTestingUtils.checkCopyAndUids(scaler, model) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 96df68dbdf05..1c2956cb8290 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.sql.{Dataset, Row} -class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + +class MinHashLSHSuite extends MLTest with DefaultReadWriteTest { @transient var dataset: Dataset[_] = _ @@ -43,6 +42,14 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa ParamsSuite.checkParams(model) } + test("setters") { + val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0))) + .setInputCol("testkeys") + .setOutputCol("testvalues") + assert(model.getInputCol === "testkeys") + assert(model.getOutputCol === "testvalues") + } + test("MinHashLSH: default params") { val rp = new MinHashLSH assert(rp.getNumHashTables === 1.0) @@ -167,4 +174,20 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(precision == 1.0) assert(recall >= 0.7) } + + test("MinHashLSHModel.transform should work with Structured Streaming") { + val localSpark = spark + import localSpark.implicits._ + + val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0))) + model.set(model.inputCol, "keys") + testTransformer[Tuple1[Vector]](dataset.toDF(), model, "keys", model.getOutputCol) { + case Row(_: Vector, output: Seq[_]) => + assert(output.length === model.randCoefficients.length) + // no AND-amplification yet: SPARK-18450, so each hash output is of length 1 + output.foreach { + case hashOutput: Vector => assert(hashOutput.size === 1) + } + } + } } diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 51db74eb739c..2d965f2ca2c5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MinMaxScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -48,9 +46,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De .setMax(5) val model = scaler.fit(df) - model.transform(df).select("expected", "scaled").collect() - .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), "Transformed vector is different with expected.") + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 === vector2, "Transformed vector is different with expected.") } MLTestingUtils.checkCopyAndUids(scaler, model) @@ -114,7 +112,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De val model = scaler.fit(df) model.transform(df).select("expected", "scaled").collect() .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), "Transformed vector is different with expected.") + assert(vector1 === vector2, "Transformed vector is different with expected.") } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index d4975c0b4e20..201a335e0d7b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -19,17 +19,15 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} + @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) -class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NGramSuite extends MLTest with DefaultReadWriteTest { - import org.apache.spark.ml.feature.NGramSuite._ import testImplicits._ test("default behavior yields bigram features") { @@ -83,16 +81,11 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setN(3) testDefaultReadWrite(t) } -} - -object NGramSuite extends SparkFunSuite { - def testNGram(t: NGram, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("nGrams", "wantedNGrams") - .collect() - .foreach { case Row(actualNGrams, wantedNGrams) => + def testNGram(t: NGram, dataFrame: DataFrame): Unit = { + testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") { + case 
Row(actualNGrams : Seq[_], wantedNGrams: Seq[_]) => assert(actualNGrams === wantedNGrams) - } + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index c75027fb4553..eff57f1223af 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -17,21 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NormalizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @transient var data: Array[Vector] = _ - @transient var dataFrame: DataFrame = _ - @transient var normalizer: Normalizer = _ @transient var l1Normalized: Array[Vector] = _ @transient var l2Normalized: Array[Vector] = _ @@ -62,49 +58,40 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Vectors.dense(0.897906166, 0.113419726, 0.42532397), Vectors.sparse(3, Seq()) ) - - dataFrame = data.map(NormalizerSuite.FeatureData).toSeq.toDF() - normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normalized_features") - } - - def collectResult(result: DataFrame): Array[Vector] = { - result.select("normalized_features").collect().map { - case Row(features: Vector) => features - } } - def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = { - assert((lhs, rhs).zipped.forall { + def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = { + assert((lhs, rhs) match { case (v1: DenseVector, v2: DenseVector) => true case (v1: SparseVector, v2: SparseVector) => true case _ => false }, "The vector type should be preserved after normalization.") } - def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = { - assert((lhs, rhs).zipped.forall { (vector1, vector2) => - vector1 ~== vector2 absTol 1E-5 - }, "The vector value is not correct after normalization.") + def assertValues(lhs: Vector, rhs: Vector): Unit = { + assert(lhs ~== rhs absTol 1E-5, "The vector value is not correct after normalization.") } test("Normalization with default parameter") { - val result = collectResult(normalizer.transform(dataFrame)) - - assertTypeOfVector(data, result) + val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized") + val dataFrame: DataFrame = data.zip(l2Normalized).seq.toDF("features", "expected") - assertValues(result, l2Normalized) + testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") { + case Row(features: Vector, normalized: Vector, expected: Vector) => + assertTypeOfVector(normalized, features) + assertValues(normalized, expected) + } } test("Normalization with setter") { - normalizer.setP(1) + val dataFrame: DataFrame = data.zip(l1Normalized).seq.toDF("features", "expected") + val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized").setP(1) - val result = collectResult(normalizer.transform(dataFrame)) - - assertTypeOfVector(data, result) - - assertValues(result, l1Normalized) + 
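// setP(1) switches the Normalizer from the default Euclidean (p = 2) norm to the L1 norm, so each
// vector is divided by the sum of its absolute values rather than by the square root of the sum of
// squares; that is why this test compares against l1Normalized while the default-parameter test
// above uses l2Normalized.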
testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") { + case Row(features: Vector, normalized: Vector, expected: Vector) => + assertTypeOfVector(normalized, features) + assertValues(normalized, expected) + } } test("read/write") { @@ -115,7 +102,3 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testDefaultReadWrite(t) } } - -private object NormalizerSuite { - case class FeatureData(features: Vector) -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala index 1d3f84558642..d549e1326227 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -17,18 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{Encoder, Row} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ -class OneHotEncoderEstimatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class OneHotEncoderEstimatorSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -57,13 +55,10 @@ class OneHotEncoderEstimatorSuite assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](df, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -87,11 +82,9 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("output")) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](df, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -103,11 +96,12 @@ class OneHotEncoderEstimatorSuite .setInputCols(Array("size")) .setOutputCols(Array("encoded")) val model = encoder.fit(df) - val output = model.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) 
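// Unlike the per-row testTransformer, testTransformerByGlobalCheckFunc hands the check function all
// collected output rows at once, which is what makes schema-level checks such as the attribute
// asserts around this comment possible (they need rows.head.schema rather than individual values).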
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } } test("input column without ML attribute") { @@ -116,11 +110,12 @@ class OneHotEncoderEstimatorSuite .setInputCols(Array("index")) .setOutputCols(Array("encoded")) val model = encoder.fit(df) - val output = model.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + } } test("read/write") { @@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite val df = spark.createDataFrame(sc.parallelize(data), schema) - val dfWithTypes = df - .withColumn("shortInput", df("input").cast(ShortType)) - .withColumn("longInput", df("input").cast(LongType)) - .withColumn("intInput", df("input").cast(IntegerType)) - .withColumn("floatInput", df("input").cast(FloatType)) - .withColumn("decimalInput", df("input").cast(DecimalType(10, 0))) - - val cols = Array("input", "shortInput", "longInput", "intInput", - "floatInput", "decimalInput") - for (col <- cols) { - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array(col)) + class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) + + for (t <- types) { + val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected")) + val estimator = new OneHotEncoderEstimator() + .setInputCols(Array("input")) .setOutputCols(Array("output")) .setDropLast(false) - val model = encoder.fit(dfWithTypes) - val encoded = model.transform(dfWithTypes) - - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) - } + val model = estimator.fit(dfWithTypes) + testTransformer(dfWithTypes, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + }(t.encoder) } } @@ -202,12 +198,16 @@ class OneHotEncoderEstimatorSuite assert(encoder.getDropLast === false) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) - }.collect().foreach { case (vec1, vec2, vec3, vec4) => - assert(vec1 === vec2) - assert(vec3 === vec4) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) } } @@ -233,12 +233,16 @@ class OneHotEncoderEstimatorSuite 
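// When several result columns are passed to testTransformer, the Row seen by the check function
// carries them in the order the column names were given, so the pattern
// Row(output1, output2, expected1, expected2) in the test above lines up with
// ("output1", "output2", "expected1", "expected2").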
.setOutputCols(Array("output1", "output2")) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) - }.collect().foreach { case (vec1, vec2, vec3, vec4) => - assert(vec1 === vec2) - assert(vec3 === vec4) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) } } @@ -253,10 +257,12 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("encoded")) val model = encoder.fit(trainingDF) - val err = intercept[SparkException] { - model.transform(testDF).show - } - err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "encoded") + } test("Can't transform on negative input") { @@ -268,10 +274,11 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("encoded")) val model = encoder.fit(trainingDF) - val err = intercept[SparkException] { - model.transform(testDF).collect() - } - err.getMessage.contains("Negative value: -1.0. Input can't be negative") + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Negative value: -1.0. Input can't be negative", + firstResultCol = "encoded") } test("Keep on invalid values: dropLast = false") { @@ -295,11 +302,9 @@ class OneHotEncoderEstimatorSuite .setDropLast(false) val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -324,11 +329,9 @@ class OneHotEncoderEstimatorSuite .setDropLast(true) val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -355,19 +358,15 @@ class OneHotEncoderEstimatorSuite val model = encoder.fit(df) model.setDropLast(false) - val encoded1 = model.transform(df) - encoded1.select("output", "expected1").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected1") { + case Row(output: Vector, expected1: Vector) => + assert(output === expected1) } model.setDropLast(true) - val encoded2 = model.transform(df) - encoded2.select("output", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected2") { + case Row(output: Vector, expected2: Vector) => + assert(output === expected2) } } @@ -392,13 +391,14 @@ class OneHotEncoderEstimatorSuite val model = 
encoder.fit(trainingDF) model.setHandleInvalid("error") - val err = intercept[SparkException] { - model.transform(testDF).collect() - } - err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + testTransformerByInterceptingException[(Double, Vector)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "output") model.setHandleInvalid("keep") - model.transform(testDF).collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](testDF, model, "output") { _ => } } test("Transforming on mismatched attributes") { @@ -413,9 +413,10 @@ class OneHotEncoderEstimatorSuite val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large") val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size") .select(col("size").as("size", testAttr.toMetadata())) - val err = intercept[Exception] { - model.transform(testDF).collect() - } - err.getMessage.contains("OneHotEncoderModel expected 2 categorical values") + testTransformerByInterceptingException[(Double)]( + testDF, + model, + expectedMessagePart = "OneHotEncoderModel expected 2 categorical values", + firstResultCol = "encoded") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index c44c6813a94b..41b32b2ffa09 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ class OneHotEncoderSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -54,16 +54,19 @@ class OneHotEncoderSuite assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val encoded = encoder.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + val expected = Seq( + (0, Vectors.sparse(3, Seq((0, 1.0)))), + (1, Vectors.sparse(3, Seq((2, 1.0)))), + (2, Vectors.sparse(3, Seq((1, 1.0)))), + (3, Vectors.sparse(3, Seq((0, 1.0)))), + (4, Vectors.sparse(3, Seq((0, 1.0)))), + (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = transformed.join(expected, "id") + testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } } test("OneHotEncoder 
dropLast = true") { @@ -71,16 +74,19 @@ class OneHotEncoderSuite val encoder = new OneHotEncoder() .setInputCol("labelIndex") .setOutputCol("labelVec") - val encoded = encoder.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), - (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) - assert(output === expected) + val expected = Seq( + (0, Vectors.sparse(2, Seq((0, 1.0)))), + (1, Vectors.sparse(2, Seq())), + (2, Vectors.sparse(2, Seq((1, 1.0)))), + (3, Vectors.sparse(2, Seq((0, 1.0)))), + (4, Vectors.sparse(2, Seq((0, 1.0)))), + (5, Vectors.sparse(2, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = transformed.join(expected, "id") + testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } } test("input column with ML attribute") { @@ -90,20 +96,22 @@ class OneHotEncoderSuite val encoder = new OneHotEncoder() .setInputCol("size") .setOutputCol("encoded") - val output = encoder.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, encoder, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } } + test("input column without ML attribute") { val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") val encoder = new OneHotEncoder() .setInputCol("index") .setOutputCol("encoded") - val output = encoder.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) + val rows = encoder.transform(df).select("encoded").collect() + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) assert(group.size === 2) assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) @@ -119,29 +127,41 @@ class OneHotEncoderSuite test("OneHotEncoder with varying types") { val df = stringIndexed() - val dfWithTypes = df - .withColumn("shortLabel", df("labelIndex").cast(ShortType)) - .withColumn("longLabel", df("labelIndex").cast(LongType)) - .withColumn("intLabel", df("labelIndex").cast(IntegerType)) - .withColumn("floatLabel", df("labelIndex").cast(FloatType)) - .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0))) - val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel", - "floatLabel", "decimalLabel") - for (col <- cols) { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val expected = Seq( + (0, Vectors.sparse(3, Seq((0, 1.0)))), + (1, Vectors.sparse(3, Seq((2, 1.0)))), + (2, Vectors.sparse(3, Seq((1, 1.0)))), + (3, Vectors.sparse(3, Seq((0, 1.0)))), + (4, Vectors.sparse(3, Seq((0, 1.0)))), + (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = df.join(expected, "id") + + class 
NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) + + for (t <- types) { + val dfWithTypes = withExpected.select(col("labelIndex") + .cast(t.numericType).as("labelIndex", attr.toMetadata()), col("expected")) val encoder = new OneHotEncoder() - .setInputCol(col) + .setInputCol("labelIndex") .setOutputCol("labelVec") .setDropLast(false) - val encoded = encoder.transform(dfWithTypes) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + + testTransformer(dfWithTypes, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + }(t.encoder) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 3067a52a4df7..531b1d7c4d9f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class PCASuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -62,10 +60,10 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val pcaModel = pca.fit(df) MLTestingUtils.checkCopyAndUids(pca, pcaModel) - - pcaModel.transform(df).select("pca_features", "expected").collect().foreach { - case Row(x: Vector, y: Vector) => - assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + testTransformer[(Vector, Vector)](df, pcaModel, "pca_features", "expected") { + case Row(result: Vector, expected: Vector) => + assert(result ~== expected absTol 1e-5, + "Transformed vector is different with expected vector.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index e4b0ddf98bfa..0be7aa6c83f2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,18 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.exceptions.TestFailedException - -import 
org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class PolynomialExpansionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class PolynomialExpansionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -60,6 +55,18 @@ class PolynomialExpansionSuite -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), Vectors.sparse(19, Array.empty, Array.empty)) + def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = { + assert((lhs, rhs) match { + case (v1: DenseVector, v2: DenseVector) => true + case (v1: SparseVector, v2: SparseVector) => true + case _ => false + }, "The vector type should be preserved after polynomial expansion.") + } + + def assertValues(lhs: Vector, rhs: Vector): Unit = { + assert(lhs ~== rhs absTol 1e-1, "The vector value is not correct after polynomial expansion.") + } + test("Polynomial expansion with default parameter") { val df = data.zip(twoDegreeExpansion).toSeq.toDF("features", "expected") @@ -67,13 +74,10 @@ class PolynomialExpansionSuite .setInputCol("features") .setOutputCol("polyFeatures") - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { - case Row(expanded: DenseVector, expected: DenseVector) => - assert(expanded ~== expected absTol 1e-1) - case Row(expanded: SparseVector, expected: SparseVector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { + case Row(expanded: Vector, expected: Vector) => + assertTypeOfVector(expanded, expected) + assertValues(expanded, expected) } } @@ -85,13 +89,10 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") .setDegree(3) - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { - case Row(expanded: DenseVector, expected: DenseVector) => - assert(expanded ~== expected absTol 1e-1) - case Row(expanded: SparseVector, expected: SparseVector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { + case Row(expanded: Vector, expected: Vector) => + assertTypeOfVector(expanded, expected) + assertValues(expanded, expected) } } @@ -103,11 +104,9 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") .setDegree(1) - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { case Row(expanded: Vector, expected: Vector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + assertValues(expanded, expected) } } @@ -133,12 +132,13 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") for (i <- Seq(10, 11)) { - val transformed = t.setDegree(i) - .transform(df) - .select(s"expectedPoly${i}size", "polyFeatures") - .rdd.map { case Row(expected: Int, v: Vector) => expected 
== v.size } - - assert(transformed.collect.forall(identity)) + testTransformer[(Vector, Int, Int)]( + df, + t.setDegree(i), + s"expectedPoly${i}size", + "polyFeatures") { case Row(size: Int, expected: Vector) => + assert(size === expected.size) + } } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index e9a75e931e6a..b009038bbd83 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql._ -import org.apache.spark.sql.functions.udf -class QuantileDiscretizerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ test("Test observed number of buckets and their sizes match expected values") { val spark = this.spark @@ -38,19 +36,19 @@ class QuantileDiscretizerSuite .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") + val model = discretizer.fit(df) - val relativeError = discretizer.getRelativeError - val isGoodBucket = udf { - (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) + testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + val relativeError = discretizer.getRelativeError + val numGoodBuckets = result.groupBy("result").count + .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}").count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") } - val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count - assert(numGoodBuckets === numBuckets, - "Bucket sizes are not within expected relative error tolerance.") } test("Test on data with high proportion of duplicated values") { @@ -65,11 +63,14 @@ class QuantileDiscretizerSuite .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets == expectedNumBuckets, - s"Observed number of buckets are not correct." 
+ - s" Expected $expectedNumBuckets but found $observedNumBuckets") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets == expectedNumBuckets, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBuckets but found $observedNumBuckets") + } } test("Test transform on data with NaN value") { @@ -88,17 +89,20 @@ class QuantileDiscretizerSuite withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { val dataFrame: DataFrame = validData.toSeq.toDF("input") - intercept[SparkException] { - discretizer.fit(dataFrame).transform(dataFrame).collect() - } + val model = discretizer.fit(dataFrame) + testTransformerByInterceptingException[(Double)]( + dataFrame, + model, + expectedMessagePart = "Bucketizer encountered NaN value.", + firstResultCol = "result") } List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ case(u, v) => discretizer.setHandleInvalid(u) val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") - val result = discretizer.fit(dataFrame).transform(dataFrame) - result.select("result", "expected").collect().foreach { + val model = discretizer.fit(dataFrame) + testTransformer[(Double, Double)](dataFrame, model, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -117,14 +121,17 @@ class QuantileDiscretizerSuite .setOutputCol("result") .setNumBuckets(5) - val result = discretizer.fit(trainDF).transform(testDF) - val firstBucketSize = result.filter(result("result") === 0.0).count - val lastBucketSize = result.filter(result("result") === 4.0).count + val model = discretizer.fit(trainDF) + testTransformerByGlobalCheckFunc[(Double)](testDF, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val firstBucketSize = result.filter(result("result") === 0.0).count + val lastBucketSize = result.filter(result("result") === 4.0).count - assert(firstBucketSize === 30L, - s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") - assert(lastBucketSize === 31L, - s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") + assert(firstBucketSize === 30L, + s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") + assert(lastBucketSize === 31L, + s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") + } } test("read/write") { @@ -132,7 +139,10 @@ class QuantileDiscretizerSuite .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setNumBuckets(6) - testDefaultReadWrite(t) + + val readDiscretizer = testDefaultReadWrite(t) + val data = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("myInputCol") + readDiscretizer.fit(data) } test("Verify resulting model has parent") { @@ -162,21 +172,24 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - - val relativeError = discretizer.getRelativeError - val isGoodBucket = udf { - (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) - } - - for (i <- 1 to 2) { - val observedNumBuckets = result.select("result" + i).distinct.count - 
assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") - - val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count - assert(numGoodBuckets === numBuckets, - "Bucket sizes are not within expected relative error tolerance.") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows => + val result = + rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2") + val relativeError = discretizer.getRelativeError + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + + val numGoodBuckets = result + .groupBy("result" + i) + .count + .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}") + .count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") + } } } @@ -193,12 +206,16 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - for (i <- 1 to 2) { - val observedNumBuckets = result.select("result" + i).distinct.count - assert(observedNumBuckets == expectedNumBucket, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBucket but found ($observedNumBuckets") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows => + val result = + rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2") + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets == expectedNumBucket, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBucket but found ($observedNumBuckets") + } } } @@ -221,9 +238,12 @@ class QuantileDiscretizerSuite withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { val dataFrame: DataFrame = validData1.zip(validData2).toSeq.toDF("input1", "input2") - intercept[SparkException] { - discretizer.fit(dataFrame).transform(dataFrame).collect() - } + val model = discretizer.fit(dataFrame) + testTransformerByInterceptingException[(Double, Double)]( + dataFrame, + model, + expectedMessagePart = "Bucketizer encountered NaN value.", + firstResultCol = "result1") } List(("keep", expectedKeep1, expectedKeep2), ("skip", expectedSkip1, expectedSkip2)).foreach { @@ -232,8 +252,14 @@ class QuantileDiscretizerSuite val dataFrame: DataFrame = validData1.zip(validData2).zip(v).zip(w).map { case (((a, b), c), d) => (a, b, c, d) }.toSeq.toDF("input1", "input2", "expected1", "expected2") - val result = discretizer.fit(dataFrame).transform(dataFrame) - result.select("result1", "expected1", "result2", "expected2").collect().foreach { + val model = discretizer.fit(dataFrame) + testTransformer[(Double, Double, Double, Double)]( + dataFrame, + model, + "result1", + "expected1", + "result2", + "expected2") { case Row(x: Double, y: Double, z: Double, w: Double) => assert(x === y && w === z) } @@ -265,9 +291,16 @@ class QuantileDiscretizerSuite .setOutputCols(Array("result1", "result2", "result3")) .setNumBucketsArray(numBucketsArray) - discretizer.fit(df).transform(df). 
- select("result1", "expected1", "result2", "expected2", "result3", "expected3") - .collect().foreach { + val model = discretizer.fit(df) + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df, + model, + "result1", + "expected1", + "result2", + "expected2", + "result3", + "expected3") { case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) => assert(r1 === e1, s"The result value is not correct after bucketing. Expected $e1 but found $r1") @@ -319,20 +352,16 @@ class QuantileDiscretizerSuite .setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3)) .fit(df) - val resultForMultiCols = plForMultiCols.transform(df) - .select("result1", "result2", "result3") - .collect() - - val resultForSingleCol = plForSingleCol.transform(df) - .select("result1", "result2", "result3") - .collect() + val expected = plForSingleCol.transform(df).select("result1", "result2", "result3").collect() - resultForSingleCol.zip(resultForMultiCols).foreach { - case (rowForSingle, rowForMultiCols) => - assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) && - rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) && - rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2)) - } + testTransformerByGlobalCheckFunc[(Double, Double, Double)]( + df, + plForMultiCols, + "result1", + "result2", + "result3") { rows => + assert(rows === expected) + } } test("Multiple Columns: Comparing setting numBuckets with setting numBucketsArray " + @@ -359,18 +388,16 @@ class QuantileDiscretizerSuite .setOutputCols(Array("result1", "result2", "result3")) .setNumBucketsArray(Array(10, 10, 10)) - val result1 = discretizerSingleNumBuckets.fit(df).transform(df) - .select("result1", "result2", "result3") - .collect() - val result2 = discretizerNumBucketsArray.fit(df).transform(df) - .select("result1", "result2", "result3") - .collect() - - result1.zip(result2).foreach { - case (row1, row2) => - assert(row1.getDouble(0) == row2.getDouble(0) && - row1.getDouble(1) == row2.getDouble(1) && - row1.getDouble(2) == row2.getDouble(2)) + val model = discretizerSingleNumBuckets.fit(df) + val expected = model.transform(df).select("result1", "result2", "result3").collect() + + testTransformerByGlobalCheckFunc[(Double, Double, Double)]( + df, + discretizerNumBucketsArray.fit(df), + "result1", + "result2", + "result3") { rows => + assert(rows === expected) } } @@ -379,7 +406,12 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBucketsArray(Array(5, 10)) - testDefaultReadWrite(discretizer) + + val readDiscretizer = testDefaultReadWrite(discretizer) + val data = Seq((1.0, 2.0), (2.0, 3.0), (3.0, 4.0)).toDF("input1", "input2") + readDiscretizer.fit(data) + assert(discretizer.hasDefault(discretizer.outputCol)) + assert(readDiscretizer.hasDefault(readDiscretizer.outputCol)) } test("Multiple Columns: Both inputCol and inputCols are set") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 5d09c90ec6df..a250331efeb1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,18 +17,39 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vectors 
+import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.sql.{DataFrame, Encoder, Row} import org.apache.spark.sql.types.DoubleType -class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RFormulaSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ + def testRFormulaTransform[A: Encoder]( + dataframe: DataFrame, + formulaModel: RFormulaModel, + expected: DataFrame, + expectedAttributes: AttributeGroup*): Unit = { + val resultSchema = formulaModel.transformSchema(dataframe.schema) + assert(resultSchema.json === expected.schema.json) + assert(resultSchema === expected.schema) + val (first +: rest) = expected.schema.fieldNames.toSeq + val expectedRows = expected.collect() + testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows => + assert(rows.head.schema.toString() == resultSchema.toString()) + for (expectedAttributeGroup <- expectedAttributes) { + val attributeGroup = + AttributeGroup.fromStructField(rows.head.schema(expectedAttributeGroup.name)) + assert(attributeGroup === expectedAttributeGroup) + } + assert(rows === expectedRows) + } + } + test("params") { ParamsSuite.checkParams(new RFormula()) } @@ -38,16 +59,11 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) MLTestingUtils.checkCopyAndUids(formula, model) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0) ).toDF("id", "v1", "v2", "features", "label") - // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString - assert(result.schema.toString == resultSchema.toString) - assert(resultSchema == expected.schema) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Double, Double)](original, model, expected) } test("features column already exists") { @@ -62,9 +78,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y") val model = formula.fit(original) + val expected = Seq( + (0, 1.0, Vectors.dense(0.0)), + (2, 2.0, Vectors.dense(2.0)) + ).toDF("x", "y", "features") val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) - assert(resultSchema.toString == model.transform(original).schema.toString) + testRFormulaTransform[(Int, Double)](original, model, expected) } test("label column already exists but forceIndexLabel was set with true") { @@ -82,9 +102,11 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul intercept[IllegalArgumentException] { model.transformSchema(original.schema) } - intercept[IllegalArgumentException] { - model.transform(original) - } + testTransformerByInterceptingException[(Int, Boolean)]( + original, + model, + "Label column already exists and is not of type NumericType.", + "x") } test("allow missing label column for test datasets") { @@ -94,22 +116,23 @@ class RFormulaSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) assert(!resultSchema.exists(_.name == "label")) - assert(resultSchema.toString == model.transform(original).schema.toString) + val expected = Seq( + (0, 1.0, Vectors.dense(0.0)), + (2, 2.0, Vectors.dense(2.0)) + ).toDF("x", "_not_y", "features") + testRFormulaTransform[(Int, Double)](original, model, expected) } test("allow empty label") { val original = Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)).toDF("id", "a", "b") val formula = new RFormula().setFormula("~ a + b") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)) ).toDF("id", "a", "b", "features") - assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Double, Double)](original, model, expected) } test("encodes string terms") { @@ -117,16 +140,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "a", "b", "features", "label") - assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) } test("encodes string terms with string indexer order type") { @@ -164,10 +184,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul var idx = 0 for (orderType <- StringIndexer.supportedStringOrderType) { val model = formula.setStringIndexerOrderType(orderType).fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) - assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected(idx).collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected(idx)) idx += 1 } } @@ -207,10 +224,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul ).toDF("id", "a", "b", "features", "label") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) - assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) } test("formula w/o intercept, we should output reference category when encoding string terms") { @@ -243,19 +257,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val formula1 = new RFormula().setFormula("id ~ a + b + c - 1") .setStringIndexerOrderType(StringIndexer.alphabetDesc) val model1 = formula1.fit(original) - val result1 = model1.transform(original) - val resultSchema1 = model1.transformSchema(original.schema) - // Note the column order is different between R and Spark. 
- val expected1 = Seq( - (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), - (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), - (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), - (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) - ).toDF("id", "a", "b", "c", "features", "label") - assert(result1.schema.toString == resultSchema1.toString) - assert(result1.collect() === expected1.collect()) - - val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) val expectedAttrs1 = new AttributeGroup( "features", Array[Attribute]( @@ -264,14 +265,20 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul new BinaryAttribute(Some("a_bar"), Some(3)), new BinaryAttribute(Some("b_zz"), Some(4)), new NumericAttribute(Some("c"), Some(5)))) - assert(attrs1 === expectedAttrs1) + // Note the column order is different between R and Spark. + val expected1 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), + (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), + (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + + testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1, expectedAttrs1) // There is no impact for string terms interaction. val formula2 = new RFormula().setFormula("id ~ a:b + c - 1") .setStringIndexerOrderType(StringIndexer.alphabetDesc) val model2 = formula2.fit(original) - val result2 = model2.transform(original) - val resultSchema2 = model2.transformSchema(original.schema) // Note the column order is different between R and Spark. val expected2 = Seq( (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0), @@ -279,10 +286,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0), (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) ).toDF("id", "a", "b", "c", "features", "label") - assert(result2.schema.toString == resultSchema2.toString) - assert(result2.collect() === expected2.collect()) - - val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) val expectedAttrs2 = new AttributeGroup( "features", Array[Attribute]( @@ -293,7 +296,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul new NumericAttribute(Some("a_bar:b_zz"), Some(5)), new NumericAttribute(Some("a_bar:b_zq"), Some(6)), new NumericAttribute(Some("c"), Some(7)))) - assert(attrs2 === expectedAttrs2) + + testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2, expectedAttrs2) } test("index string label") { @@ -302,15 +306,15 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) + val attr = NominalAttribute.defaultAttr val expected = Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) ).toDF("id", "a", "b", "features", "label") - // assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === 
expected.collect()) + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) + testRFormulaTransform[(String, String, Int)](original, model, expected) } test("force to index label even it is numeric type") { @@ -319,15 +323,15 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) - val expected = spark.createDataFrame( - Seq( + val attr = NominalAttribute.defaultAttr + val expected = Seq( (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) - ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + .toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) + testRFormulaTransform[(Double, String, Int)](original, model, expected) } test("attribute generation") { @@ -335,15 +339,20 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expected = Seq( + (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + .toDF("id", "a", "b", "features", "label") val expectedAttrs = new AttributeGroup( "features", Array( new BinaryAttribute(Some("a_bar"), Some(1)), new BinaryAttribute(Some("a_foo"), Some(2)), new NumericAttribute(Some("b"), Some(3)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs) + } test("vector attribute generation") { @@ -351,14 +360,19 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val original = Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) .toDF("id", "vec") val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val attrs = new AttributeGroup("vec", 2) + val expected = Seq( + (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0), + (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0)) + .toDF("id", "vec", "features", "label") + .select($"id", $"vec".as("vec", attrs.toMetadata()), $"features", $"label") val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("vec_0"), Some(1)), new NumericAttribute(Some("vec_1"), Some(2)))) - assert(attrs === expectedAttrs) + + testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs) } test("vector attribute generation with unnamed input attrs") { @@ -372,31 +386,31 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul NumericAttribute.defaultAttr)).toMetadata() val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expected = Seq( + (1, 
Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0), + (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0) + ).toDF("id", "vec2", "features", "label") + .select($"id", $"vec2".as("vec2", metadata), $"features", $"label") val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("vec2_0"), Some(1)), new NumericAttribute(Some("vec2_1"), Some(2)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs) } test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") val original = Seq((1, 2, 4, 2), (2, 3, 4, 1)).toDF("a", "b", "c", "d") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, 2, 4, 2, Vectors.dense(16.0), 1.0), (2, 3, 4, 1, Vectors.dense(12.0), 2.0) ).toDF("a", "b", "c", "d", "features", "label") - assert(result.collect() === expected.collect()) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected, expectedAttrs) } test("factor numeric interaction") { @@ -405,7 +419,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), @@ -414,15 +427,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("a_baz:b"), Some(1)), new NumericAttribute(Some("a_bar:b"), Some(2)), new NumericAttribute(Some("a_foo:b"), Some(3)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs) } test("factor factor interaction") { @@ -430,14 +441,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val original = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + testRFormulaTransform[(Int, String, String)](original, model, expected) val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( @@ -445,7 +454,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul new NumericAttribute(Some("a_bar:b_zz"), Some(2)), new NumericAttribute(Some("a_foo:b_zq"), Some(3)), new NumericAttribute(Some("a_foo:b_zz"), Some(4)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, 
String)](original, model, expected, expectedAttrs) } test("read/write: RFormula") { @@ -508,11 +517,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul // Handle unseen features. val formula1 = new RFormula().setFormula("id ~ a + b") - intercept[SparkException] { - formula1.fit(df1).transform(df2).collect() - } - val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2) - val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2) + testTransformerByInterceptingException[(Int, String, String)]( + df2, + formula1.fit(df1), + "Unseen label:", + "features") + val model1 = formula1.setHandleInvalid("skip").fit(df1) + val model2 = formula1.setHandleInvalid("keep").fit(df1) val expected1 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0), @@ -524,28 +535,84 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") - assert(result1.collect() === expected1.collect()) - assert(result2.collect() === expected2.collect()) + testRFormulaTransform[(Int, String, String)](df2, model1, expected1) + testRFormulaTransform[(Int, String, String)](df2, model2, expected2) // Handle unseen labels. val formula2 = new RFormula().setFormula("b ~ a + id") - intercept[SparkException] { - formula2.fit(df1).transform(df2).collect() - } - val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2) - val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2) + testTransformerByInterceptingException[(Int, String, String)]( + df2, + formula2.fit(df1), + "Unseen label:", + "label") + + val model3 = formula2.setHandleInvalid("skip").fit(df1) + val model4 = formula2.setHandleInvalid("keep").fit(df1) + val attr = NominalAttribute.defaultAttr val expected3 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0), (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0) ).toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) + val expected4 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0), (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0) ).toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) - assert(result3.collect() === expected3.collect()) - assert(result4.collect() === expected4.collect()) + testRFormulaTransform[(Int, String, String)](df2, model3, expected3) + testRFormulaTransform[(Int, String, String)](df2, model4, expected4) } + + test("Use Vectors as inputs to formula.") { + val original = Seq( + (1, 4, Vectors.dense(0.0, 0.0, 4.0)), + (2, 4, Vectors.dense(1.0, 0.0, 4.0)), + (3, 5, Vectors.dense(1.0, 0.0, 5.0)), + (4, 5, Vectors.dense(0.0, 1.0, 5.0)) + ).toDF("id", "a", "b") + val formula = new RFormula().setFormula("id ~ a + b") + val (first +: rest) = Seq("id", "a", "b", "features", "label") + testTransformer[(Int, Int, Vector)](original, formula.fit(original), first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + + val group = new AttributeGroup("b", 3) + val vectorColWithMetadata = original("b").as("b", group.toMetadata()) + val dfWithMetadata = original.withColumn("b", vectorColWithMetadata) + val model = formula.fit(dfWithMetadata) + // model should work even when applied to dataframe without 
metadata. + testTransformer[(Int, Int, Vector)](original, model, first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + } + + test("SPARK-23562 RFormula handleInvalid should handle invalid values in non-string columns.") { + val d1 = Seq( + (1001L, "a"), + (1002L, "b")).toDF("id1", "c1") + val d2 = Seq[(java.lang.Long, String)]( + (20001L, "x"), + (20002L, "y"), + (null, null)).toDF("id2", "c2") + val dataset = d1.crossJoin(d2) + + def get_output(mode: String): DataFrame = { + val formula = new RFormula().setFormula("c1 ~ id2").setHandleInvalid(mode) + formula.fit(dataset).transform(dataset).select("features", "label") + } + + assert(intercept[SparkException](get_output("error").collect()) + .getMessage.contains("Encountered null while assembling a row")) + assert(get_output("skip").count() == 4) + assert(get_output("keep").count() == 6) + } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 673a146e619f..cf09418d8e0a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.storage.StorageLevel -class SQLTransformerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class SQLTransformerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -37,14 +34,22 @@ class SQLTransformerSuite val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val sqlTrans = new SQLTransformer().setStatement( "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") - val result = sqlTrans.transform(original) - val resultSchema = sqlTrans.transformSchema(original.schema) - val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) + val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) .toDF("id", "v1", "v2", "v3", "v4") - assert(result.schema.toString == resultSchema.toString) - assert(resultSchema == expected.schema) - assert(result.collect().toSeq == expected.collect().toSeq) - assert(original.sparkSession.catalog.listTables().count() == 0) + val resultSchema = sqlTrans.transformSchema(original.schema) + testTransformerByGlobalCheckFunc[(Int, Double, Double)]( + original, + sqlTrans, + "id", + "v1", + "v2", + "v3", + "v4") { rows => + assert(rows.head.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(rows == expected.collect().toSeq) + assert(original.sparkSession.catalog.listTables().count() == 0) + } } test("read/write") { @@ -63,13 +68,13 @@ class SQLTransformerSuite } test("SPARK-22538: SQLTransformer should not unpersist given dataset") { - val df = spark.range(10) + val df = spark.range(10).toDF() df.cache() df.count() assert(df.storageLevel != StorageLevel.NONE) - new SQLTransformer() + val sqlTrans = new SQLTransformer() .setStatement("SELECT id + 1 AS id1 FROM __THIS__") - .transform(df) + 
testTransformerByGlobalCheckFunc[Long](df, sqlTrans, "id1") { _ => } assert(df.storageLevel != StorageLevel.NONE) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 350ba44baa1e..c5c49d67194e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class StandardScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -60,12 +57,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext ) } - def assertResult(df: DataFrame): Unit = { - df.select("standardized_features", "expected").collect().foreach { - case Row(vector1: Vector, vector2: Vector) => - assert(vector1 ~== vector2 absTol 1E-5, - "The vector value is not correct after standardization.") - } + def assertResult: Row => Unit = { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, + "The vector value is not correct after standardization.") } test("params") { @@ -83,7 +78,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext val standardScaler0 = standardScalerEst0.fit(df0) MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0) - assertResult(standardScaler0.transform(df0)) + testTransformer[(Vector, Vector)](df0, standardScaler0, "standardized_features", "expected")( + assertResult) } test("Standardization with setter") { @@ -112,9 +108,12 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext .setWithStd(false) .fit(df3) - assertResult(standardScaler1.transform(df1)) - assertResult(standardScaler2.transform(df2)) - assertResult(standardScaler3.transform(df3)) + testTransformer[(Vector, Vector)](df1, standardScaler1, "standardized_features", "expected")( + assertResult) + testTransformer[(Vector, Vector)](df2, standardScaler2, "standardized_features", "expected")( + assertResult) + testTransformer[(Vector, Vector)](df3, standardScaler3, "standardized_features", "expected")( + assertResult) } test("sparse data and withMean") { @@ -130,7 +129,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext .setWithMean(true) .setWithStd(false) .fit(df) - assertResult(standardScaler.transform(df)) + testTransformer[(Vector, Vector)](df, standardScaler, "standardized_features", "expected")( + assertResult) } test("StandardScaler read/write") { @@ -149,4 +149,5 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext assert(newInstance.std === instance.std) assert(newInstance.mean === instance.mean) } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 5262b146b184..21259a50916d 100755 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -17,28 +17,20 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} - -object StopWordsRemoverSuite extends SparkFunSuite { - def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("filtered", "expected") - .collect() - .foreach { case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } - } -} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} -class StopWordsRemoverSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { - import StopWordsRemoverSuite._ import testImplicits._ + def testStopWordsRemover(t: StopWordsRemover, dataFrame: DataFrame): Unit = { + testTransformer[(Array[String], Array[String])](dataFrame, t, "filtered", "expected") { + case Row(tokens: Seq[_], wantedTokens: Seq[_]) => + assert(tokens === wantedTokens) + } + } + test("StopWordsRemover default") { val remover = new StopWordsRemover() .setInputCol("raw") @@ -151,9 +143,10 @@ class StopWordsRemoverSuite .setOutputCol(outputCol) val dataSet = Seq((Seq("The", "the", "swift"), Seq("swift"))).toDF("raw", outputCol) - val thrown = intercept[IllegalArgumentException] { - testStopWordsRemover(remover, dataSet) - } - assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.") + testTransformerByInterceptingException[(Array[String], Array[String])]( + dataSet, + remover, + s"requirement failed: Column $outputCol already exists.", + "expected") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 775a04d3df05..df2436717701 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,17 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} -class StringIndexerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class StringIndexerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -46,19 +43,23 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val indexerModel = indexer.fit(df) - MLTestingUtils.checkCopyAndUids(indexer, indexerModel) - - val transformed = indexerModel.transform(df) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("a", "c", "b")) - val output = transformed.select("id", "labelIndex").rdd.map { r => - 
(r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) - assert(output === expected) + val expected = Seq( + (0, 0.0), + (1, 2.0), + (2, 1.0), + (3, 0.0), + (4, 0.0), + (5, 1.0) + ).toDF("id", "labelIndex") + + testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows => + val attr = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + assert(rows.seq === expected.collect().toSeq) + } } test("StringIndexerUnseen") { @@ -70,36 +71,38 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) + // Verify we throw by default with unseen values - intercept[SparkException] { - indexer.transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String)]( + df2, + indexer, + "Unseen label:", + "labelIndex") - indexer.setHandleInvalid("skip") // Verify that we skip the c record - val transformedSkip = indexer.transform(df2) - val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("b", "a")) - val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 1, b -> 0 - val expectedSkip = Set((0, 1.0), (1, 0.0)) - assert(outputSkip === expectedSkip) + indexer.setHandleInvalid("skip") + + val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows => + val attrSkip = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + assert(rows.seq === expectedSkip.collect().toSeq) + } indexer.setHandleInvalid("keep") - // Verify that we keep the unseen records - val transformedKeep = indexer.transform(df2) - val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("b", "a", "__unknown")) - val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet + // a -> 1, b -> 0, c -> 2, d -> 3 - val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)) - assert(outputKeep === expectedKeep) + val expectedKeep = Seq((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)).toDF() + + // Verify that we keep the unseen records + testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows => + val attrKeep = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + assert(rows === expectedKeep.collect().toSeq) + } } test("StringIndexer with a numeric input column") { @@ -109,16 +112,14 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) - val transformed = indexer.transform(df) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("100", "300", "200")) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // 100 -> 0, 200 -> 2, 300 -> 1 - val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) - assert(output === expected) + val expected = Seq((0, 
0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "id", "labelIndex") { rows => + val attr = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + assert(rows === expected.collect().toSeq) + } } test("StringIndexer with NULLs") { @@ -133,37 +134,36 @@ class StringIndexerSuite withClue("StringIndexer should throw error when setHandleInvalid=error " + "when given NULL values") { - intercept[SparkException] { - indexer.setHandleInvalid("error") - indexer.fit(df).transform(df2).collect() - } + indexer.setHandleInvalid("error") + testTransformerByInterceptingException[(Int, String)]( + df2, + indexer.fit(df), + "StringIndexer encountered NULL value.", + "labelIndex") } indexer.setHandleInvalid("skip") - val transformedSkip = indexer.fit(df).transform(df2) - val attrSkip = Attribute - .fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("b", "a")) - val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet + val modelSkip = indexer.fit(df) // a -> 1, b -> 0 - val expectedSkip = Set((0, 1.0), (1, 0.0)) - assert(outputSkip === expectedSkip) + val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df2, modelSkip, "id", "labelIndex") { rows => + val attrSkip = + Attribute.fromStructField(rows.head.schema("labelIndex")).asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + assert(rows === expectedSkip.collect().toSeq) + } indexer.setHandleInvalid("keep") - val transformedKeep = indexer.fit(df).transform(df2) - val attrKeep = Attribute - .fromStructField(transformedKeep.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("b", "a", "__unknown")) - val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 1, b -> 0, null -> 2 - val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) - assert(outputKeep === expectedKeep) + val expectedKeep = Seq((0, 1.0), (1, 0.0), (3, 2.0)).toDF() + val modelKeep = indexer.fit(df) + testTransformerByGlobalCheckFunc[(Int, String)](df2, modelKeep, "id", "labelIndex") { rows => + val attrKeep = Attribute + .fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + assert(rows === expectedKeep.collect().toSeq) + } } test("StringIndexerModel should keep silent if the input column does not exist.") { @@ -171,7 +171,9 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val df = spark.range(0L, 10L).toDF() - assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) + testTransformerByGlobalCheckFunc[Long](df, indexerModel, "id") { rows => + assert(rows.toSet === df.collect().toSet) + } } test("StringIndexerModel can't overwrite output column") { @@ -188,9 +190,12 @@ class StringIndexerSuite .setOutputCol("indexedInput") .fit(df) - intercept[IllegalArgumentException] { - indexer.setOutputCol("output").transform(df) - } + testTransformerByInterceptingException[(Int, String)]( + df, + indexer.setOutputCol("output"), + "Output column output already exists.", + "labelIndex") + } test("StringIndexer read/write") { @@ -223,7 +228,8 @@ class 
StringIndexerSuite .setInputCol("index") .setOutputCol("actual") .setLabels(labels) - idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + + testTransformer[(Int, String)](df0, idxToStr0, "actual", "expected") { case Row(actual, expected) => assert(actual === expected) } @@ -234,7 +240,8 @@ class StringIndexerSuite val idxToStr1 = new IndexToString() .setInputCol("indexWithAttr") .setOutputCol("actual") - idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + + testTransformer[(Int, String)](df1, idxToStr1, "actual", "expected") { case Row(actual, expected) => assert(actual === expected) } @@ -252,9 +259,10 @@ class StringIndexerSuite .setInputCol("labelIndex") .setOutputCol("sameLabel") .setLabels(indexer.labels) - idx2str.transform(transformed).select("label", "sameLabel").collect().foreach { - case Row(a: String, b: String) => - assert(a === b) + + testTransformer[(Int, String, Double)](transformed, idx2str, "sameLabel", "label") { + case Row(sameLabel, label) => + assert(sameLabel === label) } } @@ -286,10 +294,11 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) - val transformed = indexer.transform(df) - val attrs = - NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true) - assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "labelIndex") { rows => + val attrs = + NominalAttribute.decodeStructField(rows.head.schema("labelIndex"), preserveName = true) + assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + } } test("StringIndexer order types") { @@ -299,18 +308,17 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") - val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), - Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), - Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), - Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) + val expected = Seq(Seq((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), + Seq((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), + Seq((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), + Seq((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) var idx = 0 for (orderType <- StringIndexer.supportedStringOrderType) { - val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet - assert(output === expected(idx)) + val model = indexer.setStringOrderType(orderType).fit(df) + testTransformerByGlobalCheckFunc[(Int, String)](df, model, "id", "labelIndex") { rows => + assert(rows === expected(idx).toDF().collect().toSeq) + } idx += 1 } } @@ -328,7 +336,11 @@ class StringIndexerSuite .setOutputCol("CITYIndexed") .fit(dfNoBristol) - val dfWithIndex = model.transform(dfNoBristol) - assert(dfWithIndex.filter($"CITYIndexed" === 1.0).count == 1) + testTransformerByGlobalCheckFunc[(String, String, String)]( + dfNoBristol, + model, + "CITYIndexed") { rows => + assert(rows.toList.count(_.getDouble(0) == 1.0) === 1) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index c895659a2d8b..be59b0af2c78 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class TokenizerSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Tokenizer) @@ -42,12 +40,17 @@ class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } } -class RegexTokenizerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RegexTokenizerSuite extends MLTest with DefaultReadWriteTest { - import org.apache.spark.ml.feature.RegexTokenizerSuite._ import testImplicits._ + def testRegexTokenizer(t: RegexTokenizer, dataframe: DataFrame): Unit = { + testTransformer[(String, Seq[String])](dataframe, t, "tokens", "wantedTokens") { + case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } + test("params") { ParamsSuite.checkParams(new RegexTokenizer) } @@ -105,14 +108,3 @@ class RegexTokenizerSuite } } -object RegexTokenizerSuite extends SparkFunSuite { - - def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("tokens", "wantedTokens") - .collect() - .foreach { case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index eca065f7e775..91fb24a268b8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, udf} class VectorAssemblerSuite @@ -31,30 +31,49 @@ class VectorAssemblerSuite import testImplicits._ + @transient var dfWithNullsAndNaNs: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val sv = Vectors.sparse(2, Array(1), Array(3.0)) + dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)]( + (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (2, 1, 0.0, null, "a", sv, 6L, null), + (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null), + (4, 4, null, null, "a", sv, 9L, null), + (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null)) + 
.toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls") + } + test("params") { ParamsSuite.checkParams(new VectorAssembler) } test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble - assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) - assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0))) + assert(assemble(Array(1), keepInvalid = true)(0.0) + === Vectors.sparse(1, Array.empty, Array.empty)) + assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0) + === Vectors.sparse(2, Array(1), Array(1.0))) val dv = Vectors.dense(2.0, 0.0) - assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) + assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) === + Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0)) - assert(assemble(0.0, dv, 1.0, sv) === + assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) === Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0))) - for (v <- Seq(1, "a", null)) { - intercept[SparkException](assemble(v)) - intercept[SparkException](assemble(1.0, v)) + for (v <- Seq(1, "a")) { + intercept[SparkException](assemble(Array(1), keepInvalid = true)(v)) + intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v)) } } test("assemble should compress vectors") { import org.apache.spark.ml.feature.VectorAssembler.assemble - val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0)) + val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0)) assert(v1.isInstanceOf[SparseVector]) - val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0))) + val sv = Vectors.sparse(1, Array(0), Array(4.0)) + val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv) assert(v2.isInstanceOf[DenseVector]) } @@ -147,4 +166,94 @@ class VectorAssemblerSuite .filter(vectorUDF($"features") > 1) .count() == 1) } + + test("assemble should keep nulls when keepInvalid is true") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN)) + assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null) + === Vectors.dense(1.0, Double.NaN, Double.NaN)) + assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN)) + assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN)) + } + + test("assemble should throw errors when keepInvalid is false") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1), keepInvalid = false)(null)) + intercept[SparkException](assemble(Array(2), keepInvalid = false)(null)) + } + + test("get lengths functions") { + import org.apache.spark.ml.feature.VectorAssembler._ + val df = dfWithNullsAndNaNs + assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2)) + assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y"))) + .getMessage.contains("VectorSizeHint")) + assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"), + Seq("y"))).getMessage.contains("VectorSizeHint")) + + assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2)) + 
assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID)) + .getMessage.contains("VectorSizeHint")) + assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID)) + .getMessage.contains("VectorSizeHint")) + } + + test("Handle Invalid should behave properly") { + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + + def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = { + val attributeY = new AttributeGroup("y", 2) + val attributeZ = new AttributeGroup( + "z", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata()) + .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter) + val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata) + output.collect() + output + } + + def runWithFirstRow(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs) + output.collect() + output + } + + def runWithAllNullVectors(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode) + .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2")) + output.collect() + output + } + + // behavior when vector size hint is given + assert(runWithMetadata("keep").count() == 6, "should keep all rows") + assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls") + // should throw error with nulls + intercept[SparkException](runWithMetadata("error")) + // should throw error with NaNs + intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4")) + + // behavior when first row has information + assert(intercept[RuntimeException](runWithFirstRow("keep").count()) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls") + intercept[SparkException](runWithFirstRow("error")) + + // behavior when vector column is all null + assert(intercept[RuntimeException](runWithAllNullVectors("skip")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(intercept[NullPointerException](runWithAllNullVectors("error")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + + // behavior when scalar column is all null + assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4) + } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 69a7b75e32eb..e5675e31bbec 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import 
org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest with Logging { +class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { import testImplicits._ import VectorIndexerSuite.FeatureData @@ -128,18 +126,27 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext MLTestingUtils.checkCopyAndUids(vectorIndexer, model) - model.transform(densePoints1) // should work - model.transform(sparsePoints1) // should work + testTransformer[FeatureData](densePoints1, model, "indexed") { _ => } + testTransformer[FeatureData](sparsePoints1, model, "indexed") { _ => } + // If the data is local Dataset, it throws AssertionError directly. - intercept[AssertionError] { - model.transform(densePoints2).collect() - logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + withClue("Did not throw error when fit, transform were called on " + + "vectors of different lengths") { + testTransformerByInterceptingException[FeatureData]( + densePoints2, + model, + "VectorIndexerModel expected vector of length 3 but found length 4", + "indexed") } // If the data is distributed Dataset, it throws SparkException // which is the wrapper of AssertionError. - intercept[SparkException] { - model.transform(densePoints2.repartition(2)).collect() - logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + withClue("Did not throw error when fit, transform were called " + + "on vectors of different lengths") { + testTransformerByInterceptingException[FeatureData]( + densePoints2.repartition(2), + model, + "VectorIndexerModel expected vector of length 3 but found length 4", + "indexed") } intercept[SparkException] { vectorIndexer.fit(badPoints) @@ -178,46 +185,48 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val categoryMaps = model.categoryMaps // Chose correct categorical features assert(categoryMaps.keys.toSet === categoricalFeatures) - val transformed = model.transform(data).select("indexed") - val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) - val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) - assert(featureAttrs.name === "indexed") - assert(featureAttrs.attributes.get.length === model.numFeatures) - categoricalFeatures.foreach { feature: Int => - val origValueSet = collectedData.map(_(feature)).toSet - val targetValueIndexSet = Range(0, origValueSet.size).toSet - val catMap = categoryMaps(feature) - assert(catMap.keys.toSet === origValueSet) // Correct categories - assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices - if (origValueSet.contains(0.0)) { - assert(catMap(0.0) === 0) // value 0 gets index 0 - } - // Check transformed data - assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) - // Check metadata - val featureAttr = featureAttrs(feature) - assert(featureAttr.index.get === feature) - featureAttr match { - case attr: BinaryAttribute => - assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) - case attr: NominalAttribute => - assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) - assert(attr.isOrdinal.get === false) - case _ => - throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + - s" metadata check. 
Found feature attribute: $featureAttr.") + testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows => + val transformed = rows.map { r => Tuple1(r.getAs[Vector](0)) }.toDF("indexed") + val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(rows.head.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } } - } - // Check numerical feature metadata. - Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) - .foreach { feature: Int => - val featureAttr = featureAttrs(feature) - featureAttr match { - case attr: NumericAttribute => - assert(featureAttr.index.get === feature) - case _ => - throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + - s" metadata check. Found feature attribute: $featureAttr.") + // Check numerical feature metadata. + Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. 
Found feature attribute: $featureAttr.") + } } } } catch { @@ -236,25 +245,32 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext (sparsePoints1, sparsePoints1TestInvalid))) { val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error") val model = vectorIndexer.fit(points) - intercept[SparkException] { - model.transform(pointsTestInvalid).collect() - } + testTransformerByInterceptingException[FeatureData]( + pointsTestInvalid, + model, + "VectorIndexer encountered invalid value", + "indexed") val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip") val model1 = vectorIndexer1.fit(points) - val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed") - .collect().map(_(0)) - val transformed1 = model1.transform(points).select("indexed").collect().map(_(0)) - assert(transformed1 === invalidTransformed1) - + val expected = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, 1.0), + Vectors.dense(1.0, 3.0, 2.0)) + testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model1, "indexed") { rows => + assert(rows.map(_(0)) == expected) + } + testTransformerByGlobalCheckFunc[FeatureData](points, model1, "indexed") { rows => + assert(rows.map(_(0)) == expected) + } val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep") val model2 = vectorIndexer2.fit(points) - val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed") - .collect().map(_(0)) - assert(invalidTransformed2 === transformed1 ++ Array( - Vectors.dense(2.0, 2.0, 0.0), - Vectors.dense(0.0, 4.0, 2.0), - Vectors.dense(1.0, 3.0, 3.0)) - ) + testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model2, "indexed") { rows => + assert(rows.map(_(0)) == expected ++ Array( + Vectors.dense(2.0, 2.0, 0.0), + Vectors dense(0.0, 4.0, 2.0), + Vectors.dense(1.0, 3.0, 3.0))) + } } } @@ -263,12 +279,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val points = data.collect().map(_.getAs[Vector](0)) val vectorIndexer = getIndexer.setMaxCategories(maxCategories) val model = vectorIndexer.fit(data) - val indexedPoints = - model.transform(data).select("indexed").rdd.map(_.getAs[Vector](0)).collect() - points.zip(indexedPoints).foreach { - case (orig: SparseVector, indexed: SparseVector) => - assert(orig.indices.length == indexed.indices.length) - case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows => + points.zip(rows.map(_(0))).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } } } checkSparsity(sparsePoints1, maxCategories = 2) @@ -286,17 +302,18 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val vectorIndexer = getIndexer.setMaxCategories(2) val model = vectorIndexer.fit(densePoints1WithMeta) // Check that ML metadata are preserved. 
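    // (Clarifying note, not part of the original patch: the check below reads the AttributeGroup
    // back from the transformed "indexed" column's schema and verifies that attribute names, and
    // the max values of NumericAttributes, match the metadata attached to the input column.)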
- val indexedPoints = model.transform(densePoints1WithMeta) - val transAttributes: Array[Attribute] = - AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get - featureAttributes.zip(transAttributes).foreach { case (orig, trans) => - assert(orig.name === trans.name) - (orig, trans) match { - case (orig: NumericAttribute, trans: NumericAttribute) => - assert(orig.max.nonEmpty && orig.max === trans.max) - case _ => + testTransformerByGlobalCheckFunc[FeatureData](densePoints1WithMeta, model, "indexed") { rows => + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(rows.head.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => // do nothing // TODO: Once input features marked as categorical are handled correctly, check that here. + } } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala index f6c9a76599fa..d89d10b320d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.StreamTest class VectorSizeHintSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -40,16 +38,23 @@ class VectorSizeHintSuite val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue") val noSizeTransformer = new VectorSizeHint().setInputCol("vector") - intercept[NoSuchElementException] (noSizeTransformer.transform(data)) + testTransformerByInterceptingException[(Vector, Int)]( + data, + noSizeTransformer, + "Failed to find a default value for size", + "vector") intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema)) val noInputColTransformer = new VectorSizeHint().setSize(2) - intercept[NoSuchElementException] (noInputColTransformer.transform(data)) + testTransformerByInterceptingException[(Vector, Int)]( + data, + noInputColTransformer, + "Failed to find a default value for inputCol", + "vector") intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema)) } test("Adding size to column of vectors.") { - val size = 3 val vectorColName = "vector" val denseVector = Vectors.dense(1, 2, 3) @@ -66,12 +71,15 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - val withSize = transformer.transform(dataFrame) - assert( - AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size, - "Transformer did not add expected size data.") - val numRows = withSize.collect().length - assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.") + 
testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataFrame, transformer, vectorColName) { + rows => { + assert( + AttributeGroup.fromStructField(rows.head.schema(vectorColName)).size == size, + "Transformer did not add expected size data.") + val numRows = rows.length + assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.") + } + } } } @@ -93,14 +101,16 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - val withSize = transformer.transform(dataFrameWithMetadata) - - val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName)) - assert(newGroup.size === size, "Column has incorrect size metadata.") - assert( - newGroup.attributes.get === group.attributes.get, - "VectorSizeHint did not preserve attributes.") - withSize.collect + testTransformerByGlobalCheckFunc[(Int, Int, Int, Vector)]( + dataFrameWithMetadata, + transformer, + vectorColName) { rows => + val newGroup = AttributeGroup.fromStructField(rows.head.schema(vectorColName)) + assert(newGroup.size === size, "Column has incorrect size metadata.") + assert( + newGroup.attributes.get === group.attributes.get, + "VectorSizeHint did not preserve attributes.") + } } } @@ -120,7 +130,11 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata)) + testTransformerByInterceptingException[(Int, Int, Int, Vector)]( + dataFrameWithMetadata, + transformer, + "Trying to set size of vectors in `vector` to 4 but size already set to 3.", + vectorColName) } } @@ -136,18 +150,36 @@ class VectorSizeHintSuite .setHandleInvalid("error") .setSize(3) - intercept[SparkException](sizeHint.transform(dataWithNull).collect()) - intercept[SparkException](sizeHint.transform(dataWithShort).collect()) + testTransformerByInterceptingException[Tuple1[Vector]]( + dataWithNull, + sizeHint, + "Got null vector in VectorSizeHint", + "vector") + + testTransformerByInterceptingException[Tuple1[Vector]]( + dataWithShort, + sizeHint, + "VectorSizeHint Expecting a vector of size 3 but got 1", + "vector") sizeHint.setHandleInvalid("skip") - assert(sizeHint.transform(dataWithNull).count() === 1) - assert(sizeHint.transform(dataWithShort).count() === 1) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows => + assert(rows.length === 1) + } + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows => + assert(rows.length === 1) + } sizeHint.setHandleInvalid("optimistic") - assert(sizeHint.transform(dataWithNull).count() === 2) - assert(sizeHint.transform(dataWithShort).count() === 2) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows => + assert(rows.length === 2) + } + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows => + assert(rows.length === 2) + } } + test("read/write") { val sizeHint = new VectorSizeHint() .setInputCol("myInputCol") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 1746ce53107c..3d90f9d9ac76 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import 
org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row import org.apache.spark.sql.types.{StructField, StructType} -class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class VectorSlicerSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ test("params") { val slicer = new VectorSlicer().setInputCol("feature") @@ -84,12 +84,12 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") - def validateResults(df: DataFrame): Unit = { - df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) => + def validateResults(rows: Seq[Row]): Unit = { + rows.foreach { case Row(vec1: Vector, vec2: Vector) => assert(vec1 === vec2) } - val resultMetadata = AttributeGroup.fromStructField(df.schema("result")) - val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected")) + val resultMetadata = AttributeGroup.fromStructField(rows.head.schema("result")) + val expectedMetadata = AttributeGroup.fromStructField(rows.head.schema("expected")) assert(resultMetadata.numAttributes === expectedMetadata.numAttributes) resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) => assert(a === b) @@ -97,13 +97,16 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De } vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) vectorSlicer.setIndices(Array(1)).setNames(Array("f4")) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 6183606a7b2a..b59c4e796733 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.util.Utils -class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class 
Word2VecSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ test("params") { ParamsSuite.checkParams(new Word2Vec) @@ -36,10 +36,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec") { - - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val numOfWords = sentence.split(" ").size val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -70,17 +66,13 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul // These expectations are just magic values, characterizing the current // behavior. The test needs to be updated to be more general, see SPARK-11502 val magicExp = Vectors.dense(0.30153007534417237, -0.6833061711354689, 0.5116530778733167) - model.transform(docDF).select("result", "expected").collect().foreach { + testTransformer[(Seq[String], Vector)](docDF, model, "result", "expected") { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== magicExp absTol 1E-5, "Transformed vector is different with expected.") } } test("getVectors") { - - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -119,9 +111,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val docDF = doc.zip(doc).toDF("text", "alsotext") @@ -154,9 +143,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("window size") { - val spark = this.spark - import spark.implicits._ - val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val docDF = doc.zip(doc).toDF("text", "alsotext") @@ -222,12 +208,11 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val oldModel = new OldWord2VecModel(word2VecMap) val instance = new Word2VecModel("myWord2VecModel", oldModel) val newInstance = testDefaultReadWrite(instance) - assert(newInstance.getVectors.collect() === instance.getVectors.collect()) + assert(newInstance.getVectors.collect().sortBy(_.getString(0)) === + instance.getVectors.collect().sortBy(_.getString(0))) } test("Word2Vec works with input that is non-nullable (NGram)") { - val spark = this.spark - import spark.implicits._ val sentence = "a q s t q s t b b b s t m s t m q " val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") @@ -242,7 +227,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .fit(ngramDF) // Just test that this transformation succeeds - model.transform(ngramDF).collect() + testTransformerByGlobalCheckFunc[(Seq[String], Seq[String])](ngramDF, model, "result") { _ => } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala new file mode 100644 index 000000000000..9e538696cbcf --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.fpm + +import org.apache.spark.ml.util.MLTest +import org.apache.spark.sql.DataFrame + +class PrefixSpanSuite extends MLTest { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("PrefixSpan projections with multiple partial starts") { + val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence", + minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + val expected = Array( + (Seq(Seq(1)), 1L), + (Seq(Seq(1, 2)), 1L), + (Seq(Seq(1), Seq(1)), 1L), + (Seq(Seq(1), Seq(2)), 1L), + (Seq(Seq(1), Seq(3)), 1L), + (Seq(Seq(1, 3)), 1L), + (Seq(Seq(2)), 1L), + (Seq(Seq(2, 3)), 1L), + (Seq(Seq(2), Seq(1)), 1L), + (Seq(Seq(2), Seq(2)), 1L), + (Seq(Seq(2), Seq(3)), 1L), + (Seq(Seq(3)), 1L)) + compareResults[Int](expected, result) + } + + /* + To verify expected results for `smallTestData`, create file "prefixSpanSeqs2" with content + (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)): + 1 1 2 1 2 + 1 2 1 3 + 2 1 1 1 + 2 2 2 3 2 + 2 3 2 1 2 + 3 1 2 1 2 + 3 2 1 5 + 4 1 1 6 + In R, run: + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(prefixSpanSeqs, + parameter = 0.5, maxlen = 5 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + + sequence support + 1 <{1}> 0.75 + 2 <{2}> 0.75 + 3 <{3}> 0.50 + 4 <{1},{3}> 0.50 + 5 <{1,2}> 0.75 + */ + val smallTestData = Seq( + Seq(Seq(1, 2), Seq(3)), + Seq(Seq(1), Seq(3, 2), Seq(1, 2)), + Seq(Seq(1, 2), Seq(5)), + Seq(Seq(6))) + + val smallTestDataExpectedResult = Array( + (Seq(Seq(1)), 3L), + (Seq(Seq(2)), 3L), + (Seq(Seq(3)), 2L), + (Seq(Seq(1), Seq(3)), 2L), + (Seq(Seq(1, 2)), 3L) + ) + + test("PrefixSpan Integer type, variable-size itemsets") { + val df = smallTestData.toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan input row with nulls") { + val df = (smallTestData :+ null).toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan String type, variable-size itemsets") { + val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap + val df = smallTestData + .map(seq => seq.map(itemSet => itemSet.map(intToString))) + .toDF("sequence") + val result = 
PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[String]], Long)].collect() + + val expected = smallTestDataExpectedResult.map { case (seq, freq) => + (seq.map(itemSet => itemSet.map(intToString)), freq) + } + compareResults[String](expected, result) + } + + private def compareResults[Item]( + expectedValue: Array[(Seq[Seq[Item]], Long)], + actualValue: Array[(Seq[Seq[Item]], Long)]): Unit = { + val expectedSet = expectedValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + val actualSet = actualValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + assert(expectedSet === actualSet) + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index dba61cd1eb1c..527b3f895596 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -53,11 +53,11 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(df.count === 1) df = readImages(imagePath, null, true, -1, false, 1.0, 0) - assert(df.count === 9) + assert(df.count === 10) df = readImages(imagePath, null, true, -1, true, 1.0, 0) val countTotal = df.count - assert(countTotal === 7) + assert(countTotal === 8) df = readImages(imagePath, null, true, -1, true, 0.5, 0) // Random number about half of the size of the original dataset @@ -65,11 +65,71 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(count50 > 0 && count50 < countTotal) } + test("readImages test: recursive = false") { + val df = readImages(imagePath, null, false, 3, true, 1.0, 0) + assert(df.count() === 0) + } + + test("readImages test: read jpg image") { + val df = readImages(imagePath + "/kittens/DP153539.jpg", null, false, 3, true, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: read png image") { + val df = readImages(imagePath + "/multi-channel/BGRA.png", null, false, 3, true, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: read non image") { + val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, true, 1.0, 0) + assert(df.schema("image").dataType == columnSchema, "data do not fit ImageSchema") + assert(df.count() === 0) + } + + test("readImages test: read non image and dropImageFailures is false") { + val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, false, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: sampleRatio > 1") { + val e = intercept[IllegalArgumentException] { + readImages(imagePath, null, true, 3, true, 1.1, 0) + } + assert(e.getMessage.contains("sampleRatio")) + } + + test("readImages test: sampleRatio < 0") { + val e = intercept[IllegalArgumentException] { + readImages(imagePath, null, true, 3, true, -0.1, 0) + } + assert(e.getMessage.contains("sampleRatio")) + } + + test("readImages test: sampleRatio = 0") { + val df = readImages(imagePath, null, true, 3, true, 0.0, 0) + assert(df.count() === 0) + } + + test("readImages test: with sparkSession") { + val df = readImages(imagePath, sparkSession = spark, true, 3, true, 1.0, 0) + assert(df.count() === 8) + } + test("readImages partition test") { val df = readImages(imagePath, null, true, 3, true, 1.0, 0) assert(df.rdd.getNumPartitions === 3) } + test("readImages partition test: < 0") { + val df = readImages(imagePath, null, true, -3, true, 1.0, 0) + 
assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism) + } + + test("readImages partition test: = 0") { + val df = readImages(imagePath, null, true, 0, true, 1.0, 0) + assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism) + } + // Images with the different number of channels test("readImages pixel values test") { @@ -93,7 +153,7 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { // - default representation for 3-channel RGB images is BGR row-wise: // (B00, G00, R00, B10, G10, R10, ...) // - default representation for 4-channel RGB images is BGRA row-wise: - // (B00, G00, R00, A00, B10, G10, R10, A00, ...) + // (B00, G00, R00, A00, B10, G10, R10, A10, ...) private val firstBytes20 = Map( "grayscale.jpg" -> (("CV_8UC1", Array[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62, @@ -103,6 +163,9 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { -71, -58, -56, -73, -64))), "BGRA.png" -> (("CV_8UC4", Array[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128, - -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))) + -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))), + "BGRA_alpha_60.png" -> (("CV_8UC4", + Array[Byte](-128, -128, -8, 60, -128, -128, -8, 60, -128, + -128, -8, 60, 127, 127, -9, 60, 127, 127, -9, 60))) ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 85198ad4c913..36e06091d24d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.ml.param import java.io.{ByteArrayOutputStream, ObjectOutputStream} import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Transformer} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.MyParams +import org.apache.spark.sql.Dataset class ParamsSuite extends SparkFunSuite { @@ -430,4 +432,24 @@ object ParamsSuite extends SparkFunSuite { require(copyReturnType === obj.getClass, s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.") } + + /** + * Checks that the class throws an exception in case multiple exclusive params are set. + * The params to be checked are passed as arguments with their value. 
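+   * A hypothetical usage sketch (the transformer and param names below are illustrative only,
+   * not taken from this patch), assuming "inputCol" and "inputCols" are mutually exclusive:
+   * {{{
+   *   ParamsSuite.testExclusiveParams(new SomeFeatureTransformer(), dataset,
+   *     "inputCol" -> "features", "inputCols" -> Array("f1", "f2"))
+   * }}}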
+ */ + def testExclusiveParams( + model: Params, + dataset: Dataset[_], + paramsAndValues: (String, Any)*): Unit = { + val m = model.copy(ParamMap.empty) + paramsAndValues.foreach { case (paramName, paramValue) => + m.set(m.getParam(paramName), paramValue) + } + intercept[IllegalArgumentException] { + m match { + case t: Transformer => t.transform(dataset) + case e: Estimator[_] => e.fit(dataset) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index addcd21d50aa..e3dfe2faf569 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,8 +22,7 @@ import java.util.Random import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.WrappedArray +import scala.collection.mutable.{ArrayBuffer, WrappedArray} import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -35,21 +34,20 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.recommendation.ALS.Rating -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.recommendation.MatrixFactorizationModelSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} -import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.{DataFrame, Encoder, Row, SparkSession} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.streaming.StreamingQueryException import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class ALSSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { +class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() @@ -413,34 +411,36 @@ class ALSSuite .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) - val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map { - case Row(rating: Float, prediction: Float) => - (rating.toDouble, prediction.toDouble) + testTransformerByGlobalCheckFunc[Rating[Int]](test.toDF(), model, "rating", "prediction") { + case rows: Seq[Row] => + val predictions = rows.map(row => (row.getFloat(0).toDouble, row.getFloat(1).toDouble)) + + val rmse = + if (implicitPrefs) { + // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. + // We limit the ratings and the predictions to interval [0, 1] and compute the + // weighted RMSE with the confidence scores as weights. 
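+            // (Clarifying note, not part of the original patch: the block below computes the
+            // confidence-weighted RMSE, sqrt(sum_i c_i * err_i^2 / sum_i c_i), where
+            // c_i = 1 + alpha * |rating_i| and err_i is the difference between the clipped
+            // prediction and the clipped rating.)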
+ val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => + val confidence = 1.0 + alpha * math.abs(rating) + val rating01 = math.max(math.min(rating, 1.0), 0.0) + val prediction01 = math.max(math.min(prediction, 1.0), 0.0) + val err = prediction01 - rating01 + (confidence, confidence * err * err) + }.reduce[(Double, Double)] { case ((c0, e0), (c1, e1)) => + (c0 + c1, e0 + e1) + } + math.sqrt(weightedSumSq / totalWeight) + } else { + val errorSquares = predictions.map { case (rating, prediction) => + val err = rating - prediction + err * err + } + val mse = errorSquares.sum / errorSquares.length + math.sqrt(mse) + } + logInfo(s"Test RMSE is $rmse.") + assert(rmse < targetRMSE) } - val rmse = - if (implicitPrefs) { - // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. - // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE - // with the confidence scores as weights. - val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => - val confidence = 1.0 + alpha * math.abs(rating) - val rating01 = math.max(math.min(rating, 1.0), 0.0) - val prediction01 = math.max(math.min(prediction, 1.0), 0.0) - val err = prediction01 - rating01 - (confidence, confidence * err * err) - }.reduce { case ((c0, e0), (c1, e1)) => - (c0 + c1, e0 + e1) - } - math.sqrt(weightedSumSq / totalWeight) - } else { - val mse = predictions.map { case (rating, prediction) => - val err = rating - prediction - err * err - }.mean() - math.sqrt(mse) - } - logInfo(s"Test RMSE is $rmse.") - assert(rmse < targetRMSE) MLTestingUtils.checkCopyAndUids(als, model) } @@ -586,6 +586,68 @@ class ALSSuite allModelParamSettings, checkModelData) } + private def checkNumericTypesALS( + estimator: ALS, + spark: SparkSession, + column: String, + baseType: NumericType) + (check: (ALSModel, ALSModel) => Unit) + (check2: (ALSModel, ALSModel, DataFrame, Encoder[_]) => Unit): Unit = { + val dfs = genRatingsDFWithNumericCols(spark, column) + val df = dfs.find { + case (numericTypeWithEncoder, _) => numericTypeWithEncoder.numericType == baseType + } match { + case Some((_, df)) => df + } + val expected = estimator.fit(df) + val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) + actuals.foreach { case (_, actual) => check(expected, actual) } + actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) } + + val baseDF = dfs.find(_._1.numericType == baseType).get._2 + val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) + val cols = Seq(col(column).cast(StringType)) ++ others + val strDF = baseDF.select(cols: _*) + val thrown = intercept[IllegalArgumentException] { + estimator.fit(strDF) + } + assert(thrown.getMessage.contains( + s"$column must be of type NumericType but was actually of type StringType")) + } + + private class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Int, Double)]) + + private def genRatingsDFWithNumericCols( + spark: SparkSession, + column: String) = { + + import testImplicits._ + + val df = spark.createDataFrame(Seq( + (0, 10, 1.0), + (1, 20, 2.0), + (2, 30, 3.0), + (3, 40, 4.0), + (4, 50, 5.0) + )).toDF("user", "item", "rating") + + val others = df.columns.toSeq.diff(Seq(column)).map(col) + val types = + Seq(new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new 
NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()) + ) + types.map { t => + val cols = Seq(col(column).cast(t.numericType)) ++ others + t -> df.select(cols: _*) + } + } + test("input type validation") { val spark = this.spark import spark.implicits._ @@ -595,12 +657,16 @@ class ALSSuite val als = new ALS().setMaxIter(1).setRank(1) Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach { case (colName, sqlType) => - MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) { + checkNumericTypesALS(als, spark, colName, sqlType) { (ex, act) => - ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1) - } { (ex, act, _) => - ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~== - act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6 + ex.userFactors.first().getSeq[Float](1) === act.userFactors.first().getSeq[Float](1) + } { (ex, act, df, enc) => + val expected = ex.transform(df).selectExpr("prediction") + .first().getFloat(0) + testTransformerByGlobalCheckFunc(df, act, "prediction") { + case rows: Seq[Row] => + expected ~== rows.head.getFloat(0) absTol 1e-6 + }(enc) } } // check user/item ids falling outside of Int range @@ -628,18 +694,22 @@ class ALSSuite } withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) - assert(intercept[SparkException] { - model.transform(df.select(df("user_big").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("user_small").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_big").as("item"), df("user"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_small").as("item"), df("user"))).first - }.getMessage.contains(msg)) + def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { + assert(intercept[SparkException] { + model.transform(dataFrame).first + }.getMessage.contains(msg)) + assert(intercept[StreamingQueryException] { + testTransformer[A](dataFrame, model, "prediction") { _ => } + }.getMessage.contains(msg)) + } + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"), + df("user"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"), + df("user"))) } } @@ -662,28 +732,31 @@ class ALSSuite val knownItem = data.select(max("item")).as[Int].first() val unknownItem = knownItem + 20 val test = Seq( - (unknownUser, unknownItem), - (knownUser, unknownItem), - (unknownUser, knownItem), - (knownUser, knownItem) - ).toDF("user", "item") + (unknownUser, unknownItem, true), + (knownUser, unknownItem, true), + (unknownUser, knownItem, true), + (knownUser, knownItem, false) + ).toDF("user", "item", "expectedIsNaN") val als = new ALS().setMaxIter(1).setRank(1) // default is 'nan' val defaultModel = als.fit(data) - val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect() - assert(defaultPredictions.length == 4) - assert(defaultPredictions.slice(0, 3).forall(_.isNaN)) - 
assert(!defaultPredictions.last.isNaN) + testTransformer[(Int, Int, Boolean)](test, defaultModel, "expectedIsNaN", "prediction") { + case Row(expectedIsNaN: Boolean, prediction: Float) => + assert(prediction.isNaN === expectedIsNaN) + } // check 'drop' strategy should filter out rows with unknown users/items - val dropPredictions = defaultModel - .setColdStartStrategy("drop") - .transform(test) - .select("prediction").as[Float].collect() - assert(dropPredictions.length == 1) - assert(!dropPredictions.head.isNaN) - assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14) + val defaultPrediction = defaultModel.transform(test).select("prediction") + .as[Float].filter(!_.isNaN).first() + testTransformerByGlobalCheckFunc[(Int, Int, Boolean)](test, + defaultModel.setColdStartStrategy("drop"), "prediction") { + case rows: Seq[Row] => + val dropPredictions = rows.map(_.getFloat(0)) + assert(dropPredictions.length == 1) + assert(!dropPredictions.head.isNaN) + assert(dropPredictions.head ~== defaultPrediction relTol 1e-14) + } } test("case insensitive cold start param value") { @@ -693,7 +766,7 @@ class ALSSuite val data = ratings.toDF val model = new ALS().fit(data) Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s => - model.setColdStartStrategy(s).transform(data) + testTransformer[Rating[Int]](data, model.setColdStartStrategy(s), "prediction") { _ => } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 68a1218c23ec..9ae27339b11d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -136,6 +136,21 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { assert(importances.toArray.forall(_ >= 0.0)) } + test("prediction on single instance") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setSeed(123) + + // In this data, feature 1 is very important. 
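+    // (Note, not part of the original patch: testPredictionModelSinglePrediction is a shared
+    // MLTest helper; as used here it is expected to check that model.predict on each row's
+    // feature vector agrees with the "prediction" column produced by model.transform.)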
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val model = dt.fit(df) + testPredictionModelSinglePrediction(model, df) + } + test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( @@ -176,6 +191,20 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { TreeTests.allParamSettings ++ Map("maxDepth" -> 0), TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } + + test("label/impurity stats") { + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val dtr = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8) + val model = dtr.fit(df) + val statInfo = model.rootNode + + assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0 + && statInfo.getSumOfSquares == 600.0) + } } private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 11c593b521e6..fad11d078250 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -99,6 +99,14 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { } } + test("prediction on single instance") { + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(2) + val model = gbt.fit(trainData.toDF()) + testPredictionModelSinglePrediction(model, validationData.toDF) + } + test("Checkpointing") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index ef2ff94a5e21..997c50157dcd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -211,6 +211,14 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest assert(model.getLink === "identity") } + test("prediction on single instance") { + val glr = new GeneralizedLinearRegression + val model = glr.setFamily("gaussian").setLink("identity") + .fit(datasetGaussianIdentity) + + testPredictionModelSinglePrediction(model, datasetGaussianIdentity) + } + test("generalized linear regression: gaussian family against glm") { /* R code: @@ -485,11 +493,20 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest } [1] -0.0457441 -0.6833928 [1] 1.8121235 -0.1747493 -0.5815417 + + R code for deivance calculation: + data = cbind(y=c(0,1,0,0,0,1), x1=c(18, 12, 15, 13, 15, 16), x2=c(1,0,0,2,1,1)) + summary(glm(y~x1+x2, family=poisson, data=data.frame(data)))$deviance + [1] 3.70055 + summary(glm(y~x1+x2-1, family=poisson, data=data.frame(data)))$deviance + [1] 3.809296 */ val expected = Seq( Vectors.dense(0.0, -0.0457441, -0.6833928), Vectors.dense(1.8121235, -0.1747493, -0.5815417)) + val residualDeviancesR = Array(3.809296, 3.70055) + import 
GeneralizedLinearRegression._ var idx = 0 @@ -502,6 +519,7 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + s"$link link and fitIntercept = $fitIntercept (with zero values).") + assert(model.summary.deviance ~== residualDeviancesR(idx) absTol 1E-3) idx += 1 } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index d42cb1714478..90ceb7dee38f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,18 +17,23 @@ package org.apache.spark.ml.regression +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Random +import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel} + import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util._ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { + +class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { import testImplicits._ @@ -636,6 +641,13 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { } } + test("prediction on single instance") { + val trainer = new LinearRegression + val model = trainer.fit(datasetWithDenseFeature) + + testPredictionModelSinglePrediction(model, datasetWithDenseFeature) + } + test("linear regression model with constant label") { /* R code: @@ -1045,6 +1057,24 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { LinearRegressionSuite.allParamSettings, checkModelData) } + test("pmml export") { + val lr = new LinearRegression() + val model = lr.fit(datasetWithWeight) + def checkModel(pmml: PMML): Unit = { + val dd = pmml.getDataDictionary + assert(dd.getNumberOfFields === 3) + val fields = dd.getDataFields.asScala + assert(fields(0).getName().toString === "field_0") + assert(fields(0).getOpType() == OpType.CONTINUOUS) + val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] + val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors + val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList + assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) + assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) + } + testPMMLWrite(sc, model, checkModel) + } + test("should support all NumericType labels and weights, and not support other types") { for (solver <- Seq("auto", "l-bfgs", "normal")) { val lr = new LinearRegression().setMaxIter(1).setSolver(solver) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 8b8e8a655f47..e83c49f93297 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,22 +19,22 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest{ +class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{ import RandomForestRegressorSuite.compareAPIs + import testImplicits._ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ @@ -74,6 +74,20 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex regressionTestWithContinuousFeatures(rf) } + test("prediction on single instance") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + + val df = orderedLabeledPoints50_1000.toDF() + val model = rf.fit(df) + testPredictionModelSinglePrediction(model, df) + } + test("Feature importance with toy data") { val rf = new RandomForestRegressor() .setImpurity("variance") diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala new file mode 100644 index 000000000000..1312de3a1b52 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.stat + +import org.apache.commons.math3.distribution.{ExponentialDistribution, NormalDistribution, + RealDistribution, UniformRealDistribution} +import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => Math3KSTest} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row + +class KolmogorovSmirnovTestSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + def apacheCommonMath3EquivalenceTest( + sampleDist: RealDistribution, + theoreticalDist: RealDistribution, + theoreticalDistByName: (String, Array[Double]), + rejectNullHypothesis: Boolean): Unit = { + + // set seeds + val seed = 10L + sampleDist.reseedRandomGenerator(seed) + if (theoreticalDist != null) { + theoreticalDist.reseedRandomGenerator(seed) + } + + // Sample data from the distributions and parallelize it + val n = 100000 + val sampledArray = sampleDist.sample(n) + val sampledDF = sc.parallelize(sampledArray, 10).toDF("sample") + + // Use a apache math commons local KS test to verify calculations + val ksTest = new Math3KSTest() + val pThreshold = 0.05 + + // Comparing a standard normal sample to a standard normal distribution + val Row(pValue1: Double, statistic1: Double) = + if (theoreticalDist != null) { + val cdf = (x: Double) => theoreticalDist.cumulativeProbability(x) + KolmogorovSmirnovTest.test(sampledDF, "sample", cdf).head() + } else { + KolmogorovSmirnovTest.test(sampledDF, "sample", + theoreticalDistByName._1, + theoreticalDistByName._2: _* + ).head() + } + val theoreticalDistMath3 = if (theoreticalDist == null) { + assert(theoreticalDistByName._1 == "norm") + val params = theoreticalDistByName._2 + new NormalDistribution(params(0), params(1)) + } else { + theoreticalDist + } + val referenceStat1 = ksTest.kolmogorovSmirnovStatistic(theoreticalDistMath3, sampledArray) + val referencePVal1 = 1 - ksTest.cdf(referenceStat1, n) + // Verify vs apache math commons ks test + assert(statistic1 ~== referenceStat1 relTol 1e-4) + assert(pValue1 ~== referencePVal1 relTol 1e-4) + + if (rejectNullHypothesis) { + assert(pValue1 < pThreshold) + } else { + assert(pValue1 > pThreshold) + } + } + + test("1 sample Kolmogorov-Smirnov test: apache commons math3 implementation equivalence") { + // Create theoretical distributions + val stdNormalDist = new NormalDistribution(0.0, 1.0) + val expDist = new ExponentialDistribution(0.6) + val uniformDist = new UniformRealDistribution(0.0, 1.0) + val expDist2 = new ExponentialDistribution(0.2) + val stdNormByName = Tuple2("norm", Array(0.0, 1.0)) + + apacheCommonMath3EquivalenceTest(stdNormalDist, null, stdNormByName, false) + apacheCommonMath3EquivalenceTest(expDist, null, stdNormByName, true) + apacheCommonMath3EquivalenceTest(uniformDist, null, stdNormByName, true) + apacheCommonMath3EquivalenceTest(expDist, expDist2, null, true) + } + + test("1 sample Kolmogorov-Smirnov test: R implementation equivalence") { + /* + Comparing results with R's implementation of Kolmogorov-Smirnov for 1 sample + > sessionInfo() + R version 3.2.0 (2015-04-16) + Platform: x86_64-apple-darwin13.4.0 (64-bit) + > set.seed(20) + > v <- rnorm(20) + > v + [1] 1.16268529 -0.58592447 1.78546500 -1.33259371 -0.44656677 0.56960612 + [7] -2.88971761 -0.86901834 -0.46170268 -0.55554091 -0.02013537 -0.15038222 + [13] -0.62812676 1.32322085 
-1.52135057 -0.43742787 0.97057758 0.02822264 + [19] -0.08578219 0.38921440 + > ks.test(v, pnorm, alternative = "two.sided") + + One-sample Kolmogorov-Smirnov test + + data: v + D = 0.18874, p-value = 0.4223 + alternative hypothesis: two-sided + */ + + val rKSStat = 0.18874 + val rKSPVal = 0.4223 + val rData = sc.parallelize( + Array( + 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, + -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, + -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, + -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, + 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 + ) + ).toDF("sample") + val Row(pValue: Double, statistic: Double) = KolmogorovSmirnovTest + .test(rData, "sample", "norm", 0, 1).head() + assert(statistic ~== rKSStat relTol 1e-4) + assert(pValue ~== rKSPVal relTol 1e-4) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dbe2ea931fb9..4dbbd75d2466 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tree.impl +import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.SparkFunSuite @@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestSuite.mapToVec + private val seed = 42 + ///////////////////////////////////////////////////////////////////////////// // Tests for split calculation ///////////////////////////////////////////////////////////////////////////// @@ -90,12 +93,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("find splits for a continuous feature") { // find splits for normal case { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array.fill(200000)(math.random) + val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 5) assert(fakeMetadata.numSplits(0) === 5) @@ -106,7 +109,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // SPARK-16957: Use midpoints for split values. 
{ - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -114,7 +117,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits <= numSplits { - val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2) assert(splits === expectedSplits) @@ -122,7 +125,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits > numSplits { - val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) @@ -132,7 +135,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -147,7 +150,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the minimum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -161,12 +164,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the maximum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0, Map(), Set(), Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + .map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((1.0 + 2.0) / 2) assert(splits === expectedSplits) @@ -174,12 +178,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits for constant feature { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 0, 0).map(_.toDouble) + val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0) val featureSamplesEmpty = Array.empty[Double] val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits === Array.empty[Double]) @@ -320,10 +324,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> 
Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -336,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode.prediction === 0.0) - assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) + assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -362,10 +366,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -378,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode.prediction === 0.0) - assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) + assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -407,7 +411,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None).head + seed = 42, instr = None, prune = false).head + model.rootNode match { case n: InternalNode => n.split match { case s: CategoricalSplit => @@ -577,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { left right */ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) - val left = new LeafNode(0.0, leftImp.calculate(), leftImp) + val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp) val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) - val right = new LeafNode(2.0, rightImp.calculate(), rightImp) + val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp) - val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5), true) val parentImp = parent.impurityStats val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) - val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) + val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp) - val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) + val grandParent = 
TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0), true) val grandImp = grandParent.impurityStats // Test feature importance computed at different subtrees. @@ -613,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Forest consisting of (full tree) + (internal node with 2 leafs) val trees = Array(parent, grandParent).map { root => - new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) - .asInstanceOf[DecisionTreeModel] + new DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode], + numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel] } val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance @@ -631,13 +636,89 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } + + /////////////////////////////////////////////////////////////////////////////// + // Tests for pruning of redundant subtrees (generated by a split improving the + // impurity measure, but always leading to the same prediction). + /////////////////////////////////////////////////////////////////////////////// + + test("SPARK-3159 tree model redundancy - classification") { + // The following dataset is set up such that splitting over feature_1 for points having + // feature_0 = 0 improves the impurity measure, despite the prediction will always be 0 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val numClasses = 2 + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 5) + assert(unprunedTree.numNodes === 7) + + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } + + test("SPARK-3159 tree model redundancy - regression") { + // The following dataset is set up such that splitting over feature_0 for points having + // feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5 + // in both branches. 
+ val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.5, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4, + numClasses = 0, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 3) + assert(unprunedTree.numNodes === 5) + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } } private object RandomForestSuite { - def mapToVec(map: Map[Int, Double]): Vector = { val size = (map.keys.toSeq :+ 0).max + 1 val (indices, values) = map.toSeq.sortBy(_._1).unzip Vectors.sparse(size, indices.toArray, values.toArray) } + + @tailrec + private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = { + if (nodes.isEmpty) { + acc + } + else { + nodes.head match { + case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc) + case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2..3f03d909d4a4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite { * @param split Split for parent node * @return Parent node with children attached */ - def buildParentNode(left: Node, right: Node, split: Split): Node = { + def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = { val leftImp = left.impurityStats val rightImp = right.impurityStats val parentImp = leftImp.copy.add(rightImp) @@ -168,7 +168,15 @@ private[ml] object TreeTests extends SparkFunSuite { val gain = parentImp.calculate() - (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) val pred = parentImp.predict - new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) + if (isClassification) { + new ClassificationInternalNode(pred, parentImp.calculate(), gain, + left.asInstanceOf[ClassificationNode], right.asInstanceOf[ClassificationNode], + split, parentImp) + } else { + new RegressionInternalNode(pred, parentImp.calculate(), gain, + left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode], + split, parentImp) + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 15dade262709..e6ee7220d227 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -25,17 +25,17 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import 
org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator} import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class CrossValidatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -66,6 +66,13 @@ class CrossValidatorSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(cvModel.avgMetrics.length === lrParamMaps.length) + + val result = cvModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("cross validation with linear regression") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 9024342d9c83..cd76acf9c67b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -24,17 +24,17 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -64,6 +64,13 @@ class TrainValidationSplitSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(tvsModel.validationMetrics.length === lrParamMaps.length) + + val result = tvsModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), tvsModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("train validation with linear regression") { diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 4da95e74434e..4d9e664850c1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -19,9 +19,10 @@ package org.apache.spark.ml.util import java.io.{File, IOException} +import org.json4s.JNothing import org.scalatest.Suite -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -129,6 +130,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val shouldNotSetIfSetintParamWithDefault: IntParam = + new IntParam(this, "shouldNotSetIfSetintParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") @@ -150,6 +153,13 @@ class MyParams(override val uid: String) extends Params with MLWritable { set(doubleArrayParam -> Array(8.0, 9.0)) set(stringArrayParam -> Array("10", "11")) + def checkExclusiveParams(): Unit = { + if (isSet(shouldNotSetIfSetintParamWithDefault) && isSet(intParamWithDefault)) { + throw new SparkException("intParamWithDefault and shouldNotSetIfSetintParamWithDefault " + + "shouldn't be set at the same time") + } + } + override def copy(extra: ParamMap): Params = defaultCopy(extra) override def write: MLWriter = new DefaultParamsWriter(this) @@ -169,4 +179,65 @@ class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext val myParams = new MyParams("my_params") testDefaultReadWrite(myParams) } + + test("default param shouldn't become user-supplied param after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.shouldNotSetIfSetintParamWithDefault, 1) + myParams.checkExclusiveParams() + val loadedMyParams = testDefaultReadWrite(myParams) + loadedMyParams.checkExclusiveParams() + assert(loadedMyParams.getDefault(loadedMyParams.intParamWithDefault) == + myParams.getDefault(myParams.intParamWithDefault)) + + loadedMyParams.set(myParams.intParamWithDefault, 1) + intercept[SparkException] { + loadedMyParams.checkExclusiveParams() + } + } + + test("User-supplied value for default param should be kept after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.intParamWithDefault, 100) + val loadedMyParams = testDefaultReadWrite(myParams) + assert(loadedMyParams.get(myParams.intParamWithDefault).get == 100) + } + + test("Read metadata without default field prior to 2.4") { + // default params are saved in `paramMap` field in metadata file prior to Spark 2.4. 
+ val metadata = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.3.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata = DefaultParamsReader.parseMetadata(metadata) + val myParams = new MyParams("my_params") + assert(!myParams.isSet(myParams.intParamWithDefault)) + parsedMetadata.getAndSetParams(myParams) + + // The behavior prior to Spark 2.4, default params are set in loaded ML instance. + assert(myParams.isSet(myParams.intParamWithDefault)) + } + + test("Should raise error when read metadata without default field after Spark 2.4") { + val myParams = new MyParams("my_params") + + val metadata1 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.4.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata1 = DefaultParamsReader.parseMetadata(metadata1) + val err1 = intercept[IllegalArgumentException] { + parsedMetadata1.getAndSetParams(myParams) + } + assert(err1.getMessage().contains("Cannot recognize JSON metadata")) + + val metadata2 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"3.0.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata2 = DefaultParamsReader.parseMetadata(metadata2) + val err2 = intercept[IllegalArgumentException] { + parsedMetadata2.getAndSetParams(myParams) + } + assert(err2.getMessage().contains("Cannot recognize JSON metadata")) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 17678aa611a4..76d41f9b2371 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -22,9 +22,11 @@ import java.io.File import org.scalatest.Suite import org.apache.spark.SparkContext -import org.apache.spark.ml.{PipelineModel, Transformer} -import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.ml.{PredictionModel, Transformer} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.col import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.test.TestSparkSession import org.apache.spark.util.Utils @@ -62,8 +64,10 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => val columnNames = dataframe.schema.fieldNames val stream = MemoryStream[A] - val streamDF = stream.toDS().toDF(columnNames: _*) - + val columnsWithMetadata = dataframe.schema.map { structField => + col(structField.name).as(structField.name, structField.metadata) + } + val streamDF = stream.toDS().toDF(columnNames: _*).select(columnsWithMetadata: _*) val data = dataframe.as[A].collect() val streamOutput = transformer.transform(streamDF) @@ -108,5 +112,39 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => otherResultCols: _*)(globalCheckFunction) testTransformerOnDF(dataframe, transformer, firstResultCol, otherResultCols: _*)(globalCheckFunction) + } + + def testTransformerByInterceptingException[A : Encoder]( + dataframe: DataFrame, + transformer: Transformer, + expectedMessagePart : String, + firstResultCol: String) { + + def hasExpectedMessage(exception: Throwable): Boolean = + exception.getMessage.contains(expectedMessagePart) || + 
(exception.getCause != null && exception.getCause.getMessage.contains(expectedMessagePart)) + + withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") { + val exceptionOnDf = intercept[Throwable] { + testTransformerOnDF(dataframe, transformer, firstResultCol)(_ => Unit) + } + assert(hasExpectedMessage(exceptionOnDf)) + } + withClue(s"""Expected message part "${expectedMessagePart}" is not found in stream test.""") { + val exceptionOnStreamData = intercept[Throwable] { + testTransformerOnStreamData(dataframe, transformer, firstResultCol)(_ => Unit) + } + assert(hasExpectedMessage(exceptionOnStreamData)) + } + } + + def testPredictionModelSinglePrediction(model: PredictionModel[Vector, _], + dataset: Dataset[_]): Unit = { + + model.transform(dataset).select(model.getFeaturesCol, model.getPredictionCol) + .collect().foreach { + case Row(features: Vector, prediction: Double) => + assert(prediction === model.predict(features)) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index aef81c8c173a..5e72b4d864c1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} import org.apache.spark.ml.recommendation.{ALS, ALSModel} @@ -91,30 +91,6 @@ object MLTestingUtils extends SparkFunSuite { } } - def checkNumericTypesALS( - estimator: ALS, - spark: SparkSession, - column: String, - baseType: NumericType) - (check: (ALSModel, ALSModel) => Unit) - (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = { - val dfs = genRatingsDFWithNumericCols(spark, column) - val expected = estimator.fit(dfs(baseType)) - val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t)))) - actuals.foreach { case (_, actual) => check(expected, actual) } - actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) } - - val baseDF = dfs(baseType) - val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) - val cols = Seq(col(column).cast(StringType)) ++ others - val strDF = baseDF.select(cols: _*) - val thrown = intercept[IllegalArgumentException] { - estimator.fit(strDF) - } - assert(thrown.getMessage.contains( - s"$column must be of type NumericType but was actually of type StringType")) - } - def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = { val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction") val expected = evaluator.evaluate(dfs(DoubleType)) @@ -176,26 +152,6 @@ object MLTestingUtils extends SparkFunSuite { }.toMap } - def genRatingsDFWithNumericCols( - spark: SparkSession, - column: String): Map[NumericType, DataFrame] = { - val df = spark.createDataFrame(Seq( - (0, 10, 1.0), - (1, 20, 2.0), - (2, 30, 3.0), - (3, 40, 4.0), - (4, 50, 5.0) - )).toDF("user", "item", "rating") - - val others = df.columns.toSeq.diff(Seq(column)).map(col) - val types: Seq[NumericType] = - Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) - types.map { 
t => - val cols = Seq(col(column).cast(t)) ++ others - t -> df.select(cols: _*) - }.toMap - } - def genEvaluatorDFWithNumericLabelCol( spark: SparkSession, labelColName: String = "label", @@ -291,4 +247,25 @@ object MLTestingUtils extends SparkFunSuite { } models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)} } + + /** + * Helper function for testing different input types for "features" column. Given a DataFrame, + * generate three output DataFrames: one having vector "features" column with float precision, + * one having double array "features" column with float precision, and one having float array + * "features" column. + */ + def generateArrayFeatureDataset(dataset: Dataset[_], + featuresColName: String = "features"): (Dataset[_], Dataset[_], Dataset[_]) = { + val toFloatVectorUDF = udf { (features: Vector) => + Vectors.dense(features.toArray.map(_.toFloat.toDouble))} + val toDoubleArrayUDF = udf { (features: Vector) => features.toArray} + val toFloatArrayUDF = udf { (features: Vector) => features.toArray.map(_.toFloat)} + val newDataset = dataset.withColumn(featuresColName, toFloatVectorUDF(col(featuresColName))) + val newDatasetD = newDataset.withColumn(featuresColName, toDoubleArrayUDF(col(featuresColName))) + val newDatasetF = newDataset.withColumn(featuresColName, toFloatArrayUDF(col(featuresColName))) + assert(newDataset.schema(featuresColName).dataType.equals(new VectorUDT)) + assert(newDatasetD.schema(featuresColName).dataType.equals(new ArrayType(DoubleType, false))) + assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false))) + (newDataset, newDatasetD, newDatasetF) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala new file mode 100644 index 000000000000..d2c4832b12ba --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.dmg.pmml.PMML +import org.scalatest.Suite + +import org.apache.spark.SparkContext +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +trait PMMLReadWriteTest extends TempDirectory { self: Suite => + /** + * Test PMML export. Requires exported model is small enough to be loaded locally. + * Checks that the model can be exported and the result is valid PMML, but does not check + * the specific contents of the model. 
+ */ + def testPMMLWrite[T <: Params with GeneralMLWritable](sc: SparkContext, instance: T, + checkModelData: PMML => Unit): Unit = { + val uid = instance.uid + val subdirName = Identifiable.randomUID("pmml-") + + val subdir = new File(tempDir, subdirName) + val path = new File(subdir, uid).getPath + + instance.write.format("pmml").save(path) + intercept[IOException] { + instance.write.format("pmml").save(path) + } + instance.write.format("pmml").overwrite().save(path) + val pmmlStr = sc.textFile(path).collect.mkString("\n") + val pmmlModel = PMMLUtils.loadFromString(pmmlStr) + assert(pmmlModel.getHeader().getApplication().getName().startsWith("Apache Spark")) + checkModelData(pmmlModel) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala similarity index 51% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala rename to mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala index f872e0f4b65d..dbdc69f95d84 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala @@ -14,25 +14,30 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.ml.util -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import java.io.StringReader +import javax.xml.bind.Unmarshaller +import javax.xml.transform.Source + +import org.dmg.pmml._ +import org.jpmml.model.{ImportFilter, JAXBUtil} +import org.xml.sax.InputSource /** - * A driver configuration step for mounting user-specified secrets onto user-specified paths. - * - * @param bootstrap a utility actually handling mounting of the secrets. + * Testing utils for working with PMML. + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). */ -private[spark] class DriverMountSecretsStep( - bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val (pod, container) = bootstrap.mountSecrets( - driverSpec.driverPod, driverSpec.driverContainer) - driverSpec.copy( - driverPod = pod, - driverContainer = container - ) +private[spark] object PMMLUtils { + /** + * :: Experimental :: + * Load a PMML model from a string. Note: for testing only, PMML model evaluation is supported + * through external spark-packages. + */ + def loadFromString(input: String): PMML = { + val is = new StringReader(input) + val transformed = ImportFilter.apply(new InputSource(is)) + JAXBUtil.unmarshalPMML(transformed) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala new file mode 100644 index 000000000000..f4c1f0bdb32c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.SparkException +import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.sql.{DataFrame, SparkSession} + +class FakeLinearRegressionWriter extends MLWriterFormat { + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + +class FakeLinearRegressionWriterWithName extends MLFormatRegister { + override def format(): String = "fakeWithName" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + + +class DuplicateLinearRegressionWriter1 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class DuplicateLinearRegressionWriter2 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class ReadWriteSuite extends MLTest { + + import testImplicits._ + + private val seed: Int = 42 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 0.0, weights = Array(1.0, 2.0), xMean = Array(0.0, 1.0), + xVariance = Array(2.0, 1.0), nPoints = 10, seed, eps = 0.2)).map(_.asML).toDF() + } + + test("unsupported/non existent export formats") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + // Does not exist with a long class name + val thrownDNE = intercept[SparkException] { + model.write.format("com.holdenkarau.boop").save("boop") + } + assert(thrownDNE.getMessage(). + contains("Could not load requested format")) + + // Does not exist with a short name + val thrownDNEShort = intercept[SparkException] { + model.write.format("boop").save("boop") + } + assert(thrownDNEShort.getMessage(). + contains("Could not load requested format")) + + // Check with a valid class that is not a writer format. 
+ val thrownInvalid = intercept[SparkException] { + model.write.format("org.apache.spark.SparkContext").save("boop2") + } + assert(thrownInvalid.getMessage() + .contains("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat")) + } + + test("invalid paths fail") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("pmml").save("") + } + assert(thrown.getMessage().contains("Can not create a Path from an empty string")) + } + + test("dummy export format is called") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("org.apache.spark.ml.util.FakeLinearRegressionWriter").save("name") + } + assert(thrown.getMessage().contains("Fake writer doesn't write")) + val thrownWithName = intercept[Exception] { + model.write.format("fakeWithName").save("name") + } + assert(thrownWithName.getMessage().contains("Fake writer doesn't write")) + } + + test("duplicate format raises error") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("dupe").save("dupepanda") + } + assert(thrown.getMessage().contains("Multiple writers found for")) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 00d7e2f2d386..1b98250061c7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -89,7 +89,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setInitializationMode("k-means||") .setInitializationSteps(10) .setSeed(seed) - val initialCenters = km.initKMeansParallel(normedData).map(_.vector) + + val distanceMeasureInstance = new EuclideanDistanceMeasure + val initialCenters = km.initKMeansParallel(normedData, distanceMeasureInstance).map(_.vector) assert(initialCenters.length === initialCenters.distinct.length) assert(initialCenters.length <= numDistinctPoints) @@ -104,7 +106,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setInitializationMode("k-means||") .setInitializationSteps(10) .setSeed(seed) - val initialCenters2 = km2.initKMeansParallel(normedData).map(_.vector) + val initialCenters2 = km2.initKMeansParallel(normedData, distanceMeasureInstance).map(_.vector) assert(initialCenters2.length === initialCenters2.distinct.length) assert(initialCenters2.length === k) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 4074bead421e..217b4a35438f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -495,4 +495,18 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(mlDenseVectorToArray(dv) === mlDenseVectorToArray(newDV)) assert(mlSparseVectorToArray(sv) === mlSparseVectorToArray(newSV)) } + + test("sparse vector only support non-negative length") { + val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) + val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + assert(v1.size === 0) + assert(v2.size === 0) + + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array(1), Array(2.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array((1, 2.0))) + 
} + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 441d0f7614bf..bc59f3f4125f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, @@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite { Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { - if (i < 1000) { + if (i < 1001) { arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) } else if (i < 2000) { arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) diff --git a/pom.xml b/pom.xml index 1b3716437646..88e77ff87474 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -130,7 +130,7 @@ 1.2.1 10.12.1.1 1.8.2 - 1.4.1 + 1.4.3 nohive 1.6.0 9.3.20.v20170531 @@ -160,7 +160,7 @@ 1.9.13 2.6.7 2.6.7.1 - 1.1.2.6 + 1.1.7.1 1.1.2 1.2.0-incubating 1.10 @@ -185,6 +185,10 @@ 2.8 1.8 1.0.0 + 0.8.0 ${java.home} @@ -575,7 +579,7 @@ commons-net commons-net - 2.2 + 3.1 io.netty @@ -705,7 +709,13 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.11 + 3.5.3 + + + com.fasterxml.jackson.core + * + + org.scala-lang @@ -732,11 +742,6 @@ scala-parser-combinators_${scala.binary.version} 1.0.4 - - org.scala-lang - scalap - ${scala.version} - jline @@ -1735,10 +1740,6 @@ org.apache.hive hive-storage-api - - io.airlift - slice - @@ -1752,6 +1753,10 @@ org.apache.hadoop hadoop-common + + org.apache.hadoop + hadoop-mapreduce-client-core + org.apache.orc orc-core @@ -2666,6 +2671,15 @@ + + hadoop-3.1 + + 3.1.0 + 2.12.0 + 3.4.9 + + + yarn diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 2ef0e7b40d94..adde213e361f 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -88,7 +88,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "2.0.0" + val previousSparkVersion = "2.2.0" val project = projectRef.project val fullId = "spark-" + project + "_2.11" mimaDefaultSettings ++ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 81584af6813e..7d0e88ee20c3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,8 +34,53 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + // Exclude rules for 2.4.x + lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. 
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printVersionAndExit"), + + // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"), + + // [SPARK-20659] Remove StorageStatus, or make it private + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOnHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOnHeapStorageMemory"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.getExecutorStorageStatus"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numBlocks"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocks"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.containsBlock"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddBlocksById"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), + + // [SPARK-23455][ML] Default Params in ML should be saved separately in metadata + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.paramMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$paramMap_="), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="), + + // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), + 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this") + ) + // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-22897] Expose stageAttemptId in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"), + // SPARK-22789: Map-only continuous processing execution ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), @@ -95,7 +140,40 @@ object MimaExcludes { // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), - ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"), + + // [SPARK-21728][CORE] Allow SparkSubmit to use Logging + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFileList"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFile"), + + // [SPARK-21714][CORE][YARN] Avoiding re-uploading remote resources in yarn client mode + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment"), + + // [SPARK-22324][SQL][PYTHON] Upgrade Arrow to 0.8.0 + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.network.util.AbstractFileRegion.transfered"), + + // [SPARK-20643][CORE] Add listener implementation to collect app state + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$5"), + + // [SPARK-20648][CORE] Port JobsTab and StageTab to the new UI backend + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$12"), + + // [SPARK-21462][SS] Added batchId to StreamingQueryProgress.json + // [SPARK-21409][SS] Expose state store memory usage in SQL metrics and progress updates + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"), + + // [SPARK-22278][SS] Expose current event time watermark and current processing time in GroupState + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentWatermarkMs"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentProcessingTimeMs"), + + // [SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_="), + + // [SPARK-18619][ML] Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.Bucketizer.getHandleInvalid"), + 
ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid") ) // Exclude rules for 2.2.x @@ -1079,6 +1157,7 @@ object MimaExcludes { } def excludes(version: String) = version match { + case v if v.startsWith("2.4") => v24excludes case v if v.startsWith("2.3") => v23excludes case v if v.startsWith("2.2") => v22excludes case v if v.startsWith("2.1") => v21excludes diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh b/python/.coveragerc old mode 100755 new mode 100644 similarity index 56% rename from resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh rename to python/.coveragerc index 82559889f4be..b3339cd356a6 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh +++ b/python/.coveragerc @@ -1,4 +1,3 @@ -#!/bin/bash # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -16,22 +15,7 @@ # limitations under the License. # -# echo commands to the terminal output -set -ex - -# Check whether there is a passwd entry for the container UID -myuid=$(id -u) -mygid=$(id -g) -uidentry=$(getent passwd $myuid) - -# If there is no passwd entry for the container UID, attempt to create one -if [ -z "$uidentry" ] ; then - if [ -w /etc/passwd ] ; then - echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd - else - echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" - fi -fi - -# Execute the container CMD under tini for better hygiene -/sbin/tini -s -- "$@" +[run] +branch = true +parallel = true +data_file = ${COVERAGE_DIR}/coverage_data/coverage diff --git a/python/README.md b/python/README.md index 3f17fdb98a08..2e0112da58b9 100644 --- a/python/README.md +++ b/python/README.md @@ -22,7 +22,7 @@ This packaging is currently experimental and may change in future versions (alth Using PySpark requires the Spark JARs, and if you are building this from source please see the builder instructions at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). -The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to setup your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). +The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to set up your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). **NOTE:** If you are using this with a Spark standalone cluster you must ensure that the version (including minor version) matches or you may experience odd errors. 
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 4d142c91629c..58218918693c 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -54,6 +54,7 @@ from pyspark.taskcontext import TaskContext from pyspark.profiler import Profiler, BasicProfiler from pyspark.version import __version__ +from pyspark._globals import _NoValue def since(version): diff --git a/python/pyspark/_globals.py b/python/pyspark/_globals.py new file mode 100644 index 000000000000..8e6099db0996 --- /dev/null +++ b/python/pyspark/_globals.py @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Module defining global singleton classes. + +This module raises a RuntimeError if an attempt to reload it is made. In that +way the identities of the classes defined here are fixed and will remain so +even if pyspark itself is reloaded. In particular, a function like the following +will still work correctly after pyspark is reloaded: + + def foo(arg=pyspark._NoValue): + if arg is pyspark._NoValue: + ... + +See gh-7844 for a discussion of the reload problem that motivated this module. + +Note that this approach is taken from NumPy. +""" + +__ALL__ = ['_NoValue'] + + +# Disallow reloading this module so as to preserve the identities of the +# classes defined here. +if '_is_loaded' in globals(): + raise RuntimeError('Reloading pyspark._globals is not allowed') +_is_loaded = True + + +class _NoValueType(object): + """Special keyword value. + + The instance of this class may be used as the default value assigned to a + deprecated keyword in order to check if it has been given a user defined + value. + + This class was copied from NumPy.
+ """ + __instance = None + + def __new__(cls): + # ensure that only one instance exists + if not cls.__instance: + cls.__instance = super(_NoValueType, cls).__new__(cls) + return cls.__instance + + # needed for python 2 to preserve identity through a pickle + def __reduce__(self): + return (self.__class__, ()) + + def __repr__(self): + return "" + + +_NoValue = _NoValueType() diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 6ef8cf53cc74..f730d290273f 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -94,7 +94,6 @@ else: import socketserver as SocketServer import threading -from pyspark.cloudpickle import CloudPickler from pyspark.serializers import read_int, PickleSerializer @@ -266,4 +265,4 @@ def _start_update_server(): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 02fc515fb824..b3dfc99962a3 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -162,4 +162,4 @@ def clear(self): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 40e91a2d0655..ea845b98b3db 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -57,7 +57,6 @@ import types import weakref -from pyspark.util import _exception_message if sys.version < '3': from pickle import Pickler @@ -181,6 +180,32 @@ def _builtin_type(name): return getattr(types, name) +def _make__new__factory(type_): + def _factory(): + return type_.__new__ + return _factory + + +# NOTE: These need to be module globals so that they're pickleable as globals. +_get_dict_new = _make__new__factory(dict) +_get_frozenset_new = _make__new__factory(frozenset) +_get_list_new = _make__new__factory(list) +_get_set_new = _make__new__factory(set) +_get_tuple_new = _make__new__factory(tuple) +_get_object_new = _make__new__factory(object) + +# Pre-defined set of builtin_function_or_method instances that can be +# serialized. +_BUILTIN_TYPE_CONSTRUCTORS = { + dict.__new__: _get_dict_new, + frozenset.__new__: _get_frozenset_new, + set.__new__: _get_set_new, + list.__new__: _get_list_new, + tuple.__new__: _get_tuple_new, + object.__new__: _get_object_new, +} + + if sys.version_info < (3, 4): def _walk_global_ops(code): """ @@ -237,28 +262,16 @@ def dump(self, obj): if 'recursion' in e.args[0]: msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) - except pickle.PickleError: - raise - except Exception as e: - emsg = _exception_message(e) - if "'i' format requires" in emsg: - msg = "Object too large to serialize: %s" % emsg else: - msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) - print_exec(sys.stderr) - raise pickle.PicklingError(msg) - + raise def save_memoryview(self, obj): - """Fallback to save_string""" - Pickler.save_string(self, str(obj)) + self.save(obj.tobytes()) + dispatch[memoryview] = save_memoryview - def save_buffer(self, obj): - """Fallback to save_string""" - Pickler.save_string(self,str(obj)) - if PY3: - dispatch[memoryview] = save_memoryview - else: + if not PY3: + def save_buffer(self, obj): + self.save(str(obj)) dispatch[buffer] = save_buffer def save_unsupported(self, obj): @@ -318,6 +331,24 @@ def save_function(self, obj, name=None): Determines what kind of function obj is (e.g. 
lambda, defined at interactive prompt, etc) and handles the pickling appropriately. """ + try: + should_special_case = obj in _BUILTIN_TYPE_CONSTRUCTORS + except TypeError: + # Methods of builtin types aren't hashable in python 2. + should_special_case = False + + if should_special_case: + # We keep a special-cased cache of built-in type constructors at + # global scope, because these functions are structured very + # differently in different python versions and implementations (for + # example, they're instances of types.BuiltinFunctionType in + # CPython, but they're ordinary types.FunctionType instances in + # PyPy). + # + # If the function we've received is in that cache, we just + # serialize it as a lookup into the cache. + return self.save_reduce(_BUILTIN_TYPE_CONSTRUCTORS[obj], (), obj=obj) + write = self.write if name is None: @@ -344,7 +375,7 @@ def save_function(self, obj, name=None): return self.save_global(obj, name) # a builtin_function_or_method which comes in as an attribute of some - # object (e.g., object.__new__, itertools.chain.from_iterable) will end + # object (e.g., itertools.chain.from_iterable) will end # up with modname "__main__" and so end up here. But these functions # have no __code__ attribute in CPython, so the handling for # user-defined functions below will fail. @@ -352,16 +383,13 @@ def save_function(self, obj, name=None): # for different python versions. if not hasattr(obj, '__code__'): if PY3: - if sys.version_info < (3, 4): - raise pickle.PicklingError("Can't pickle %r" % obj) - else: - rv = obj.__reduce_ex__(self.proto) + rv = obj.__reduce_ex__(self.proto) else: if hasattr(obj, '__self__'): rv = (getattr, (obj.__self__, name)) else: raise pickle.PicklingError("Can't pickle %r" % obj) - return Pickler.save_reduce(self, obj=obj, *rv) + return self.save_reduce(obj=obj, *rv) # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a @@ -420,20 +448,18 @@ def save_dynamic_class(self, obj): from global modules. """ clsdict = dict(obj.__dict__) # copy dict proxy to a dict - if not isinstance(clsdict.get('__dict__', None), property): - # don't extract dict that are properties - clsdict.pop('__dict__', None) - clsdict.pop('__weakref__', None) - - # hack as __new__ is stored differently in the __dict__ - new_override = clsdict.get('__new__', None) - if new_override: - clsdict['__new__'] = obj.__new__ - - # namedtuple is a special case for Spark where we use the _load_namedtuple function - if getattr(obj, '_is_namedtuple_', False): - self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) - return + clsdict.pop('__weakref__', None) + + # On PyPy, __doc__ is a readonly attribute, so we need to include it in + # the initial skeleton class. This is safe because we know that the + # doc can't participate in a cycle with the original class. + type_kwargs = {'__doc__': clsdict.pop('__doc__', None)} + + # If type overrides __dict__ as a property, include it in the type kwargs. + # In Python 2, we can't set this attribute after construction. + __dict__ = clsdict.pop('__dict__', None) + if isinstance(__dict__, property): + type_kwargs['__dict__'] = __dict__ save = self.save write = self.write @@ -453,23 +479,12 @@ def save_dynamic_class(self, obj): # Push the rehydration function. save(_rehydrate_skeleton_class) - # Mark the start of the args for the rehydration function. + # Mark the start of the args tuple for the rehydration function. 
write(pickle.MARK) - # On PyPy, __doc__ is a readonly attribute, so we need to include it in - # the initial skeleton class. This is safe because we know that the - # doc can't participate in a cycle with the original class. - doc_dict = {'__doc__': clsdict.pop('__doc__', None)} - - # Create and memoize an empty class with obj's name and bases. - save(type(obj)) - save(( - obj.__name__, - obj.__bases__, - doc_dict, - )) - write(pickle.REDUCE) - self.memoize(obj) + # Create and memoize an skeleton class with obj's name and bases. + tp = type(obj) + self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj) # Now save the rest of obj's __dict__. Any references to obj # encountered while saving will point to the skeleton class. @@ -522,17 +537,22 @@ def save_function_tuple(self, func): self.memoize(func) # save the rest of the func data needed by _fill_function - save(f_globals) - save(defaults) - save(dct) - save(func.__module__) - save(closure_values) + state = { + 'globals': f_globals, + 'defaults': defaults, + 'dict': dct, + 'module': func.__module__, + 'closure_values': closure_values, + } + if hasattr(func, '__qualname__'): + state['qualname'] = func.__qualname__ + save(state) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple _extract_code_globals_cache = ( weakref.WeakKeyDictionary() - if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info") + if not hasattr(sys, "pypy_version_info") else {}) @classmethod @@ -608,37 +628,22 @@ def save_global(self, obj, name=None, pack=struct.pack): The name of this method is somewhat misleading: all types get dispatched here. """ - if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": - if obj in _BUILTIN_TYPE_NAMES: - return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - - if name is None: - name = obj.__name__ - - modname = getattr(obj, "__module__", None) - if modname is None: - try: - # whichmodule() could fail, see - # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling - modname = pickle.whichmodule(obj, name) - except Exception: - modname = '__main__' + if obj.__module__ == "__main__": + return self.save_dynamic_class(obj) - if modname == '__main__': - themodule = None - else: - __import__(modname) - themodule = sys.modules[modname] - self.modules.add(themodule) + try: + return Pickler.save_global(self, obj, name=name) + except Exception: + if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": + if obj in _BUILTIN_TYPE_NAMES: + return self.save_reduce( + _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - if hasattr(themodule, name) and getattr(themodule, name) is obj: - return Pickler.save_global(self, obj, name) + typ = type(obj) + if typ is not obj and isinstance(obj, (type, types.ClassType)): + return self.save_dynamic_class(obj) - typ = type(obj) - if typ is not obj and isinstance(obj, (type, types.ClassType)): - self.save_dynamic_class(obj) - else: - raise pickle.PicklingError("Can't pickle %r" % obj) + raise dispatch[type] = save_global dispatch[types.ClassType] = save_global @@ -709,12 +714,7 @@ def save_property(self, obj): dispatch[property] = save_property def save_classmethod(self, obj): - try: - orig_func = obj.__func__ - except AttributeError: # Python 2.6 - orig_func = obj.__get__(None, object) - if isinstance(obj, classmethod): - orig_func = orig_func.__func__ # Unbind + orig_func = obj.__func__ self.save_reduce(type(obj), (orig_func,), obj=obj) dispatch[classmethod] = save_classmethod 
dispatch[staticmethod] = save_classmethod @@ -754,64 +754,6 @@ def __getattribute__(self, item): if type(operator.attrgetter) is type: dispatch[operator.attrgetter] = save_attrgetter - def save_reduce(self, func, args, state=None, - listitems=None, dictitems=None, obj=None): - # Assert that args is a tuple or None - if not isinstance(args, tuple): - raise pickle.PicklingError("args from reduce() should be a tuple") - - # Assert that func is callable - if not hasattr(func, '__call__'): - raise pickle.PicklingError("func from reduce should be callable") - - save = self.save - write = self.write - - # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ - if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__": - cls = args[0] - if not hasattr(cls, "__new__"): - raise pickle.PicklingError( - "args[0] from __newobj__ args has no __new__") - if obj is not None and cls is not obj.__class__: - raise pickle.PicklingError( - "args[0] from __newobj__ args has the wrong class") - args = args[1:] - save(cls) - - save(args) - write(pickle.NEWOBJ) - else: - save(func) - save(args) - write(pickle.REDUCE) - - if obj is not None: - self.memoize(obj) - - # More new special cases (that work with older protocols as - # well): when __reduce__ returns a tuple with 4 or 5 items, - # the 4th and 5th item should be iterators that provide list - # items and dict items (as (key, value) tuples), or None. - - if listitems is not None: - self._batch_appends(listitems) - - if dictitems is not None: - self._batch_setitems(dictitems) - - if state is not None: - save(state) - write(pickle.BUILD) - - def save_partial(self, obj): - """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" - self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) - - if sys.version_info < (2,7): # 2.7 supports partial pickling - dispatch[partial] = save_partial - - def save_file(self, obj): """Save a file""" try: @@ -867,23 +809,21 @@ def save_not_implemented(self, obj): dispatch[type(Ellipsis)] = save_ellipsis dispatch[type(NotImplemented)] = save_not_implemented - # WeakSet was added in 2.7. - if hasattr(weakref, 'WeakSet'): - def save_weakset(self, obj): - self.save_reduce(weakref.WeakSet, (list(obj),)) - - dispatch[weakref.WeakSet] = save_weakset + def save_weakset(self, obj): + self.save_reduce(weakref.WeakSet, (list(obj),)) - """Special functions for Add-on libraries""" - def inject_addons(self): - """Plug in system. Register additional pickling functions if modules already loaded""" - pass + dispatch[weakref.WeakSet] = save_weakset def save_logger(self, obj): self.save_reduce(logging.getLogger, (obj.name,), obj=obj) dispatch[logging.Logger] = save_logger + """Special functions for Add-on libraries""" + def inject_addons(self): + """Plug in system. Register additional pickling functions if modules already loaded""" + pass + # Tornado support @@ -913,11 +853,12 @@ def dump(obj, file, protocol=2): def dumps(obj, protocol=2): file = StringIO() - - cp = CloudPickler(file,protocol) - cp.dump(obj) - - return file.getvalue() + try: + cp = CloudPickler(file,protocol) + cp.dump(obj) + return file.getvalue() + finally: + file.close() # including pickles unloading functions in this namespace load = pickle.load @@ -1019,18 +960,40 @@ def __reduce__(cls): return cls.__name__ -def _fill_function(func, globals, defaults, dict, module, closure_values): - """ Fills in the rest of function data into the skeleton function object - that were created via _make_skel_func(). 
+def _fill_function(*args): + """Fills in the rest of function data into the skeleton function object + + The skeleton itself is create by _make_skel_func(). """ - func.__globals__.update(globals) - func.__defaults__ = defaults - func.__dict__ = dict - func.__module__ = module + if len(args) == 2: + func = args[0] + state = args[1] + elif len(args) == 5: + # Backwards compat for cloudpickle v0.4.0, after which the `module` + # argument was introduced + func = args[0] + keys = ['globals', 'defaults', 'dict', 'closure_values'] + state = dict(zip(keys, args[1:])) + elif len(args) == 6: + # Backwards compat for cloudpickle v0.4.1, after which the function + # state was passed as a dict to the _fill_function it-self. + func = args[0] + keys = ['globals', 'defaults', 'dict', 'module', 'closure_values'] + state = dict(zip(keys, args[1:])) + else: + raise ValueError('Unexpected _fill_value arguments: %r' % (args,)) + + func.__globals__.update(state['globals']) + func.__defaults__ = state['defaults'] + func.__dict__ = state['dict'] + if 'module' in state: + func.__module__ = state['module'] + if 'qualname' in state: + func.__qualname__ = state['qualname'] cells = func.__closure__ if cells is not None: - for cell, value in zip(cells, closure_values): + for cell, value in zip(cells, state['closure_values']): if value is not _empty_cell_value: cell_set(cell, value) @@ -1087,13 +1050,6 @@ def _find_module(mod_name): file.close() return path, description -def _load_namedtuple(name, fields): - """ - Loads a class generated by namedtuple - """ - from collections import namedtuple - return namedtuple(name, fields) - """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 491b3a81972b..ab429d9ab10d 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -217,7 +217,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 24905f1c97b2..7c664966ed74 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1035,7 +1035,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 7f06d4288c87..7bed5216eabf 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -89,7 +89,7 @@ def shutdown(code): signal.signal(SIGTERM, SIG_DFL) # Send SIGHUP to notify workers of shutdown os.kill(0, SIGHUP) - exit(code) + sys.exit(code) def handle_sigterm(*args): shutdown(1) diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py index 212a618b767a..9cf0e8c8d2fe 100755 --- a/python/pyspark/find_spark_home.py +++ b/python/pyspark/find_spark_home.py @@ -68,7 +68,7 @@ def is_spark_home(path): return next(path for path in paths if is_spark_home(path)) except StopIteration: print("Could not find valid SPARK_HOME while searching {0}".format(paths), file=sys.stderr) - exit(-1) + sys.exit(-1) if __name__ == "__main__": print(_find_spark_home()) diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index b27e91a4cc25..6af084adcf37 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -884,6 +884,7 @@ def nlargest(n, 
iterable, key=None): if __name__ == "__main__": import doctest + import sys (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 27ad1e80aa0d..ec17653a1adf 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -16,6 +16,7 @@ # import operator +import sys from multiprocessing.pool import ThreadPool from pyspark import since, keyword_only @@ -1542,12 +1543,12 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \ - rawPredicitionCol="rawPrediction") + rawPredictionCol="rawPrediction") """ super(MultilayerPerceptronClassifier, self).__init__() self._java_obj = self._new_java_obj( @@ -1561,12 +1562,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \ - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): Sets params for MultilayerPerceptronClassifier. """ kwargs = self._input_kwargs @@ -2043,4 +2044,4 @@ def _to_java(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 66fb00508522..b3d5fb17f6b8 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper @@ -403,17 +405,23 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol typeConverter=TypeConverters.toString) initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. 
" + + "Supported options: 'euclidean' and 'cosine'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None): + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, + distanceMeasure="euclidean"): """ __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None) + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \ + distanceMeasure="euclidean") """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) - self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20) + self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, + distanceMeasure="euclidean") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -423,10 +431,12 @@ def _create_model(self, java_model): @keyword_only @since("1.5.0") def setParams(self, featuresCol="features", predictionCol="prediction", k=2, - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None): + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, + distanceMeasure="euclidean"): """ setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None) + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \ + distanceMeasure="euclidean") Sets params for KMeans. """ @@ -475,6 +485,20 @@ def getInitSteps(self): """ return self.getOrDefault(self.initSteps) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` + """ + return self.getOrDefault(self.distanceMeasure) + class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ @@ -1159,4 +1183,4 @@ def getKeepLastCheckpoint(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index aa8dbe708a11..8eaf07645a37 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys from abc import abstractmethod, ABCMeta from pyspark import since, keyword_only @@ -334,7 +335,13 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, .. note:: Experimental Evaluator for Clustering results, which expects two input - columns: prediction and features. + columns: prediction and features. The metric computes the Silhouette + measure using the squared Euclidean distance. + + The Silhouette is a measure for the validation of the consistency + within clusters. It ranges between 1 and -1, where a value close to + 1 means that the points in a cluster are close to the other points + in the same cluster and far from the points of the other clusters. 
>>> from pyspark.ml.linalg import Vectors >>> featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), @@ -356,18 +363,21 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (silhouette)", typeConverter=TypeConverters.toString) + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + + "Supported options: 'squaredEuclidean' and 'cosine'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette"): + metricName="silhouette", distanceMeasure="squaredEuclidean"): """ __init__(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette") + metricName="silhouette", distanceMeasure="squaredEuclidean") """ super(ClusteringEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid) - self._setDefault(metricName="silhouette") + self._setDefault(metricName="silhouette", distanceMeasure="squaredEuclidean") kwargs = self._input_kwargs self._set(**kwargs) @@ -388,15 +398,30 @@ def getMetricName(self): @keyword_only @since("2.3.0") def setParams(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette"): + metricName="silhouette", distanceMeasure="squaredEuclidean"): """ setParams(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette") + metricName="silhouette", distanceMeasure="squaredEuclidean") Sets params for clustering evaluator. """ kwargs = self._input_kwargs return self._set(**kwargs) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` + """ + return self.getOrDefault(self.distanceMeasure) + + if __name__ == "__main__": import doctest import tempfile @@ -422,4 +447,4 @@ def setParams(self, predictionCol="prediction", featuresCol="features", except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 13bf95cce40b..cdda30cfab48 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -19,12 +19,12 @@ if sys.version > '3': basestring = str -from pyspark import since, keyword_only +from pyspark import since, keyword_only, SparkContext from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.linalg import _convert_to_vector from pyspark.ml.param.shared import * from pyspark.ml.util import JavaMLReadable, JavaMLWritable -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm from pyspark.ml.common import inherit_doc __all__ = ['Binarizer', @@ -45,6 +45,7 @@ 'NGram', 'Normalizer', 'OneHotEncoder', + 'OneHotEncoderEstimator', 'OneHotEncoderModel', 'PCA', 'PCAModel', 'PolynomialExpansion', 'QuantileDiscretizer', @@ -402,8 +403,84 @@ def getSplits(self): return self.getOrDefault(self.splits) +class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): + """ + Params for :py:attr:`CountVectorizer` and :py:attr:`CountVectorizerModel`. + """ + + minTF = Param( + Params._dummy(), "minTF", "Filter to ignore rare words in" + + " a document. 
For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then this " + + "specifies a fraction (out of the document's token count). Note that the parameter is " + + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", + typeConverter=TypeConverters.toFloat) + minDF = Param( + Params._dummy(), "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0", typeConverter=TypeConverters.toFloat) + maxDF = Param( + Params._dummy(), "maxDF", "Specifies the maximum number of" + + " different documents a term could appear in to be included in the vocabulary." + + " A term that appears more than the threshold will be ignored. If this is an" + + " integer >= 1, this specifies the maximum number of documents the term could appear in;" + + " if this is a double in [0,1), then this specifies the maximum" + + " fraction of documents the term could appear in." + + " Default (2^63) - 1", typeConverter=TypeConverters.toFloat) + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", + typeConverter=TypeConverters.toInt) + binary = Param( + Params._dummy(), "binary", "Binary toggle to control the output vector values." + + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + + " for discrete probabilistic models that model binary events rather than integer counts." + + " Default False", typeConverter=TypeConverters.toBoolean) + + def __init__(self, *args): + super(_CountVectorizerParams, self).__init__(*args) + self._setDefault(minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False) + + @since("1.6.0") + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + @since("1.6.0") + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + @since("2.4.0") + def getMaxDF(self): + """ + Gets the value of maxDF or its default value. + """ + return self.getOrDefault(self.maxDF) + + @since("1.6.0") + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + + @inherit_doc -class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. @@ -436,54 +513,40 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, >>> loadedModel = CountVectorizerModel.load(modelPath) >>> loadedModel.vocabulary == model.vocabulary True + >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"], + ... 
inputCol="raw", outputCol="vectors") + >>> fromVocabModel.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... .. versionadded:: 1.6.0 """ - minTF = Param( - Params._dummy(), "minTF", "Filter to ignore rare words in" + - " a document. For each document, terms with frequency/count less than the given" + - " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + - " times the term must appear in the document); if this is a double in [0,1), then this " + - "specifies a fraction (out of the document's token count). Note that the parameter is " + - "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", - typeConverter=TypeConverters.toFloat) - minDF = Param( - Params._dummy(), "minDF", "Specifies the minimum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + - " Default 1.0", typeConverter=TypeConverters.toFloat) - vocabSize = Param( - Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", - typeConverter=TypeConverters.toInt) - binary = Param( - Params._dummy(), "binary", "Binary toggle to control the output vector values." + - " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + - " for discrete probabilistic models that model binary events rather than integer counts." + - " Default False", typeConverter=TypeConverters.toBoolean) - @keyword_only - def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, - outputCol=None): + def __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, + inputCol=None, outputCol=None): """ - __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ - outputCol=None) + __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ + inputCol=None,outputCol=None) """ super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") - def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, - outputCol=None): + def setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, + inputCol=None, outputCol=None): """ - setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ - outputCol=None) + setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ + inputCol=None, outputCol=None) Set the params for the CountVectorizer """ kwargs = self._input_kwargs @@ -496,13 +559,6 @@ def setMinTF(self, value): """ return self._set(minTF=value) - @since("1.6.0") - def getMinTF(self): - """ - Gets the value of minTF or its default value. 
- """ - return self.getOrDefault(self.minTF) - @since("1.6.0") def setMinDF(self, value): """ @@ -510,12 +566,12 @@ def setMinDF(self, value): """ return self._set(minDF=value) - @since("1.6.0") - def getMinDF(self): + @since("2.4.0") + def setMaxDF(self, value): """ - Gets the value of minDF or its default value. + Sets the value of :py:attr:`maxDF`. """ - return self.getOrDefault(self.minDF) + return self._set(maxDF=value) @since("1.6.0") def setVocabSize(self, value): @@ -524,13 +580,6 @@ def setVocabSize(self, value): """ return self._set(vocabSize=value) - @since("1.6.0") - def getVocabSize(self): - """ - Gets the value of vocabSize or its default value. - """ - return self.getOrDefault(self.vocabSize) - @since("2.0.0") def setBinary(self, value): """ @@ -538,24 +587,40 @@ def setBinary(self, value): """ return self._set(binary=value) - @since("2.0.0") - def getBinary(self): - """ - Gets the value of binary or its default value. - """ - return self.getOrDefault(self.binary) - def _create_model(self, java_model): return CountVectorizerModel(java_model) -class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): +@inherit_doc +class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Model fitted by :py:class:`CountVectorizer`. .. versionadded:: 1.6.0 """ + @classmethod + @since("2.4.0") + def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None): + """ + Construct the model directly from a vocabulary list of strings, + requires an active SparkContext. + """ + sc = SparkContext._active_spark_context + java_class = sc._gateway.jvm.java.lang.String + jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class) + model = CountVectorizerModel._create_from_java_class( + "org.apache.spark.ml.feature.CountVectorizerModel", jvocab) + model.setInputCol(inputCol) + if outputCol is not None: + model.setOutputCol(outputCol) + if minTF is not None: + model.setMinTF(minTF) + if binary is not None: + model.setBinary(binary) + model._set(vocabSize=len(vocabulary)) + return model + @property @since("1.6.0") def vocabulary(self): @@ -564,6 +629,20 @@ def vocabulary(self): """ return self._call_java("vocabulary") + @since("2.4.0") + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + return self._set(minTF=value) + + @since("2.4.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + return self._set(binary=value) + @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): @@ -740,9 +819,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, >>> df = spark.createDataFrame(data, cols) >>> hasher = FeatureHasher(inputCols=cols, outputCol="features") >>> hasher.transform(df).head().features - SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0}) + SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0}) >>> hasher.setCategoricalCols(["real"]).transform(df).head().features - SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0}) + SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0}) >>> hasherPath = temp_path + "/hasher" >>> hasher.save(hasherPath) >>> loadedHasher = FeatureHasher.load(hasherPath) @@ -1577,6 +1656,9 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, .. note:: This is different from scikit-learn's OneHotEncoder, which keeps all categories. 
The output vectors are sparse. + .. note:: Deprecated in 2.3.0. :py:class:`OneHotEncoderEstimator` will be renamed to + :py:class:`OneHotEncoder` and this :py:class:`OneHotEncoder` will be removed in 3.0.0. + .. seealso:: :py:class:`StringIndexer` for converting categorical values into @@ -1641,6 +1723,118 @@ def getDropLast(self): return self.getOrDefault(self.dropLast) +@inherit_doc +class OneHotEncoderEstimator(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): + """ + A one-hot encoder that maps a column of category indices to a column of binary vectors, with + at most a single one-value per row that indicates the input category index. + For example with 5 categories, an input value of 2.0 would map to an output vector of + `[0.0, 0.0, 1.0, 0.0]`. + The last category is not included by default (configurable via `dropLast`), + because it makes the vector entries sum up to one, and hence linearly dependent. + So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + + Note: This is different from scikit-learn's OneHotEncoder, which keeps all categories. + The output vectors are sparse. + + When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is + added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros + vector. + + Note: When encoding multi-column by using `inputCols` and `outputCols` params, input/output + cols come in pairs, specified by the order in the arrays, and each pair is treated + independently. + + See `StringIndexer` for converting categorical values into category indices + + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"]) + >>> ohe = OneHotEncoderEstimator(inputCols=["input"], outputCols=["output"]) + >>> model = ohe.fit(df) + >>> model.transform(df).head().output + SparseVector(2, {0: 1.0}) + >>> ohePath = temp_path + "/oheEstimator" + >>> ohe.save(ohePath) + >>> loadedOHE = OneHotEncoderEstimator.load(ohePath) + >>> loadedOHE.getInputCols() == ohe.getInputCols() + True + >>> modelPath = temp_path + "/ohe-model" + >>> model.save(modelPath) + >>> loadedModel = OneHotEncoderModel.load(modelPath) + >>> loadedModel.categorySizes == model.categorySizes + True + + .. versionadded:: 2.3.0 + """ + + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data during " + + "transform(). Options are 'keep' (invalid data presented as an extra " + + "categorical feature) or error (throw an error). 
Note that this Param " + + "is only used during transform; during fitting, invalid data will " + + "result in an error.", + typeConverter=TypeConverters.toString) + + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category", + typeConverter=TypeConverters.toBoolean) + + @keyword_only + def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True): + """ + __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) + """ + super(OneHotEncoderEstimator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.feature.OneHotEncoderEstimator", self.uid) + self._setDefault(handleInvalid="error", dropLast=True) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.3.0") + def setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True): + """ + setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) + Sets params for this OneHotEncoderEstimator. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.3.0") + def setDropLast(self, value): + """ + Sets the value of :py:attr:`dropLast`. + """ + return self._set(dropLast=value) + + @since("2.3.0") + def getDropLast(self): + """ + Gets the value of dropLast or its default value. + """ + return self.getOrDefault(self.dropLast) + + def _create_model(self, java_model): + return OneHotEncoderModel(java_model) + + +class OneHotEncoderModel(JavaModel, JavaMLReadable, JavaMLWritable): + """ + Model fitted by :py:class:`OneHotEncoderEstimator`. + + .. versionadded:: 2.3.0 + """ + + @property + @since("2.3.0") + def categorySizes(self): + """ + Original number of categories for each feature being encoded. + The array contains one value for each input column, in order. + """ + return self._call_java("categorySizes") + + @inherit_doc class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): @@ -2148,9 +2342,38 @@ def mean(self): return self._call_java("mean") +class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol): + """ + Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`. + """ + + stringOrderType = Param(Params._dummy(), "stringOrderType", + "How to order labels of string column. The first label after " + + "ordering is assigned an index of 0. Supported options: " + + "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", + typeConverter=TypeConverters.toString) + + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "or NULL values) in features and label column of string type. " + + "Options are 'skip' (filter out rows with invalid data), " + + "error (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + + def __init__(self, *args): + super(_StringIndexerParams, self).__init__(*args) + self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") + + @since("2.3.0") + def getStringOrderType(self): + """ + Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. 
+ """ + return self.getOrDefault(self.stringOrderType) + + @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, - JavaMLWritable): +class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLWritable): """ A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. @@ -2194,23 +2417,16 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)] + >>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"], + ... inputCol="label", outputCol="indexed", handleInvalid="error") + >>> result = fromlabelsModel.transform(stringIndDf) + >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]), + ... key=lambda x: x[0]) + [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)] .. versionadded:: 1.4.0 """ - stringOrderType = Param(Params._dummy(), "stringOrderType", - "How to order labels of string column. The first label after " + - "ordering is assigned an index of 0. Supported options: " + - "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", - typeConverter=TypeConverters.toString) - - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + - "or NULL values) in features and label column of string type. " + - "Options are 'skip' (filter out rows with invalid data), " + - "error (throw an error), or 'keep' (put invalid data " + - "in a special additional bucket, at index numLabels).", - typeConverter=TypeConverters.toString) - @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", stringOrderType="frequencyDesc"): @@ -2220,7 +2436,6 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", """ super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) - self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -2246,21 +2461,33 @@ def setStringOrderType(self, value): """ return self._set(stringOrderType=value) - @since("2.3.0") - def getStringOrderType(self): - """ - Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. - """ - return self.getOrDefault(self.stringOrderType) - -class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): +class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaMLWritable): """ Model fitted by :py:class:`StringIndexer`. .. versionadded:: 1.4.0 """ + @classmethod + @since("2.4.0") + def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None): + """ + Construct the model directly from an array of label strings, + requires an active SparkContext. 
+ """ + sc = SparkContext._active_spark_context + java_class = sc._gateway.jvm.java.lang.String + jlabels = StringIndexerModel._new_java_array(labels, java_class) + model = StringIndexerModel._create_from_java_class( + "org.apache.spark.ml.feature.StringIndexerModel", jlabels) + model.setInputCol(inputCol) + if outputCol is not None: + model.setOutputCol(outputCol) + if handleInvalid is not None: + model.setHandleInvalid(handleInvalid) + return model + @property @since("1.5.0") def labels(self): @@ -2269,6 +2496,13 @@ def labels(self): """ return self._call_java("labels") + @since("2.4.0") + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + @inherit_doc class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): @@ -2467,7 +2701,8 @@ def setParams(self, inputCol=None, outputCol=None): @inherit_doc -class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable): +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ A feature transformer that merges multiple columns into a vector column. @@ -2485,25 +2720,56 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl >>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath) >>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs True + >>> dfWithNullsAndNaNs = spark.createDataFrame( + ... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"]) + >>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features", + ... handleInvalid="keep") + >>> vecAssembler2.transform(dfWithNullsAndNaNs).show() + +---+---+----+-------------+ + | a| b| c| features| + +---+---+----+-------------+ + |1.0|2.0|null|[1.0,2.0,NaN]| + |3.0|NaN| 4.0|[3.0,NaN,4.0]| + |5.0|6.0| 7.0|[5.0,6.0,7.0]| + +---+---+----+-------------+ + ... + >>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show() + +---+---+---+-------------+ + | a| b| c| features| + +---+---+---+-------------+ + |5.0|6.0|7.0|[5.0,6.0,7.0]| + +---+---+---+-------------+ + ... .. versionadded:: 1.4.0 """ + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " + + "and NaN values). Options are 'skip' (filter out rows with invalid " + + "data), 'error' (throw an error), or 'keep' (return relevant number " + + "of NaN in the output). Column lengths are taken from the size of ML " + + "Attribute Group, which can be set using `VectorSizeHint` in a " + + "pipeline before `VectorAssembler`. 
Column lengths can also be " + + "inferred from first rows of the data since it is safe to do so but " + + "only in case of 'error' or 'skip').", + typeConverter=TypeConverters.toString) + @keyword_only - def __init__(self, inputCols=None, outputCol=None): + def __init__(self, inputCols=None, outputCol=None, handleInvalid="error"): """ - __init__(self, inputCols=None, outputCol=None) + __init__(self, inputCols=None, outputCol=None, handleInvalid="error") """ super(VectorAssembler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) + self._setDefault(handleInvalid="error") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, inputCols=None, outputCol=None): + def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"): """ - setParams(self, inputCols=None, outputCol=None) + setParams(self, inputCols=None, outputCol=None, handleInvalid="error") Sets params for this VectorAssembler. """ kwargs = self._input_kwargs @@ -3324,7 +3590,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja selectorType = Param(Params._dummy(), "selectorType", "The selector type of the ChisqSelector. " + - "Supported options: numTopFeatures (default), percentile and fpr.", + "Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.", typeConverter=TypeConverters.toString) numTopFeatures = \ @@ -3601,4 +3867,4 @@ def setSize(self, value): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index dd7dda5f0312..b8dafd49d354 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -144,7 +144,7 @@ def freqItemsets(self): @since("2.2.0") def associationRules(self): """ - Data with three columns: + DataFrame with three columns: * `antecedent` - Array of the same type as the input column. * `consequent` - Array of the same type as the input column. * `confidence` - Confidence for the rule (`DoubleType`). diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index c9b840276f67..96d702f84483 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -24,6 +24,8 @@ :members: """ +import sys + import numpy as np from pyspark import SparkContext from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string @@ -40,6 +42,7 @@ class _ImageSchema(object): def __init__(self): self._imageSchema = None self._ocvTypes = None + self._columnSchema = None self._imageFields = None self._undefinedImageType = None @@ -49,7 +52,7 @@ def imageSchema(self): Returns the image schema. :return: a :class:`StructType` with a single column of images - named "image" (nullable). + named "image" (nullable) and having the same type returned by :meth:`columnSchema`. .. versionadded:: 2.3.0 """ @@ -75,6 +78,23 @@ def ocvTypes(self): self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()) return self._ocvTypes + @property + def columnSchema(self): + """ + Returns the schema for the image column. + + :return: a :class:`StructType` for image column, + ``struct``. + + .. 
versionadded:: 2.4.0 + """ + + if self._columnSchema is None: + ctx = SparkContext._active_spark_context + jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.columnSchema() + self._columnSchema = _parse_datatype_json_string(jschema.json()) + return self._columnSchema + @property def imageFields(self): """ @@ -194,9 +214,9 @@ def readImages(self, path, recursive=False, numPartitions=-1, :return: a :class:`DataFrame` with a single column of "images", see ImageSchema for details. - >>> df = ImageSchema.readImages('python/test_support/image/kittens', recursive=True) + >>> df = ImageSchema.readImages('data/mllib/images/kittens', recursive=True) >>> df.count() - 4 + 5 .. versionadded:: 2.3.0 """ @@ -216,3 +236,25 @@ def readImages(self, path, recursive=False, numPartitions=-1, def _disallow_instance(_): raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") _ImageSchema.__init__ = _disallow_instance + + +def _test(): + import doctest + import pyspark.ml.image + globs = pyspark.ml.image.__dict__.copy() + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.image tests")\ + .getOrCreate() + globs['spark'] = spark + + (failure_count, test_count) = doctest.testmod( + pyspark.ml.image, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index ad1b487676fa..6a611a2b5b59 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -1158,7 +1158,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index d55d209d0939..6e9e0a34cdfd 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -119,10 +119,12 @@ def get$Name(self): ("inputCol", "input column name.", None, "TypeConverters.toString"), ("inputCols", "input column names.", None, "TypeConverters.toListString"), ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), + ("outputCols", "output column names.", None, "TypeConverters.toListString"), ("numFeatures", "number of features.", None, "TypeConverters.toInt"), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + - "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, - "TypeConverters.toInt"), + "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " + + "this setting will be ignored if the checkpoint directory is not set in the SparkContext.", + None, "TypeConverters.toInt"), ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"), ("tol", "the convergence tolerance for iterative algorithms (>= 0).", None, "TypeConverters.toFloat"), @@ -155,6 +157,11 @@ def get$Name(self): "TypeConverters.toInt"), ("parallelism", "the number of threads to use when running parallel algorithms (>= 1).", "1", "TypeConverters.toInt"), + ("collectSubModels", "Param for whether to collect a list of sub-models trained during " + + "tuning. If set to false, then only the single best sub-model will be available after " + + "fitting. If set to true, then all sub-models will be available. 
Warning: For large " + + "models, collecting all sub-models can cause OOMs on the Spark driver.", + "False", "TypeConverters.toBoolean"), ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] code = [] diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index e5c5ddfba6c1..08408ee8fbfc 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -256,6 +256,29 @@ def getOutputCol(self): return self.getOrDefault(self.outputCol) +class HasOutputCols(Params): + """ + Mixin for param outputCols: output column names. + """ + + outputCols = Param(Params._dummy(), "outputCols", "output column names.", typeConverter=TypeConverters.toListString) + + def __init__(self): + super(HasOutputCols, self).__init__() + + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + def getOutputCols(self): + """ + Gets the value of outputCols or its default value. + """ + return self.getOrDefault(self.outputCols) + + class HasNumFeatures(Params): """ Mixin for param numFeatures: number of features. @@ -281,10 +304,10 @@ def getNumFeatures(self): class HasCheckpointInterval(Params): """ - Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. """ - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", typeConverter=TypeConverters.toInt) + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.", typeConverter=TypeConverters.toInt) def __init__(self): super(HasCheckpointInterval, self).__init__() @@ -632,6 +655,30 @@ def getParallelism(self): return self.getOrDefault(self.parallelism) +class HasCollectSubModels(Params): + """ + Mixin for param collectSubModels: Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver. + """ + + collectSubModels = Param(Params._dummy(), "collectSubModels", "Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.", typeConverter=TypeConverters.toBoolean) + + def __init__(self): + super(HasCollectSubModels, self).__init__() + self._setDefault(collectSubModels=False) + + def setCollectSubModels(self, value): + """ + Sets the value of :py:attr:`collectSubModels`. 
+ """ + return self._set(collectSubModels=value) + + def getCollectSubModels(self): + """ + Gets the value of collectSubModels or its default value. + """ + return self.getOrDefault(self.collectSubModels) + + class HasLoss(Params): """ Mixin for param loss: the loss function to be optimized. diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index e8bcbe4cd34c..a8eae9bd268d 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel @@ -480,4 +482,4 @@ def recommendForItemSubset(self, dataset, numUsers): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f0812bd1d4a3..9a66d87d7f21 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from pyspark import since, keyword_only @@ -335,10 +336,10 @@ def rootMeanSquaredError(self): @since("2.0.0") def r2(self): """ - Returns R^2^, the coefficient of determination. + Returns R^2, the coefficient of determination. .. seealso:: `Wikipedia coefficient of determination \ - ` + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark @@ -346,6 +347,20 @@ def r2(self): """ return self._call_java("r2") + @property + @since("2.4.0") + def r2adj(self): + """ + Returns Adjusted R^2, the adjusted coefficient of determination. + + .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2 \ + `_ + + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark versions. + """ + return self._call_java("r2adj") + @property @since("2.0.0") def residuals(self): @@ -1812,4 +1827,4 @@ def __repr__(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 079b0833e1c6..a06ab31a7a56 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -15,9 +15,13 @@ # limitations under the License. # +import sys + from pyspark import since, SparkContext from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.wrapper import _jvm +from pyspark.ml.wrapper import JavaWrapper, _jvm +from pyspark.sql.column import Column, _to_seq +from pyspark.sql.functions import lit class ChiSquareTest(object): @@ -30,32 +34,6 @@ class ChiSquareTest(object): The null hypothesis is that the occurrence of the outcomes is statistically independent. - :param dataset: - DataFrame of categorical labels and categorical features. - Real-valued features will be treated as categorical for each distinct value. - :param featuresCol: - Name of features column in dataset, of type `Vector` (`VectorUDT`). - :param labelCol: - Name of label column in dataset, of any numerical type. - :return: - DataFrame containing the test result for every feature against the label. - This DataFrame will contain a single Row with the following fields: - - `pValues: Vector` - - `degreesOfFreedom: Array[Int]` - - `statistics: Vector` - Each of these fields has one value per feature. 
- - >>> from pyspark.ml.linalg import Vectors - >>> from pyspark.ml.stat import ChiSquareTest - >>> dataset = [[0, Vectors.dense([0, 0, 1])], - ... [0, Vectors.dense([1, 0, 1])], - ... [1, Vectors.dense([2, 1, 1])], - ... [1, Vectors.dense([3, 1, 1])]] - >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) - >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') - >>> chiSqResult.select("degreesOfFreedom").collect()[0] - Row(degreesOfFreedom=[3, 1, 0]) - .. versionadded:: 2.2.0 """ @@ -64,6 +42,32 @@ class ChiSquareTest(object): def test(dataset, featuresCol, labelCol): """ Perform a Pearson's independence test using dataset. + + :param dataset: + DataFrame of categorical labels and categorical features. + Real-valued features will be treated as categorical for each distinct value. + :param featuresCol: + Name of features column in dataset, of type `Vector` (`VectorUDT`). + :param labelCol: + Name of label column in dataset, of any numerical type. + :return: + DataFrame containing the test result for every feature against the label. + This DataFrame will contain a single Row with the following fields: + - `pValues: Vector` + - `degreesOfFreedom: Array[Int]` + - `statistics: Vector` + Each of these fields has one value per feature. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import ChiSquareTest + >>> dataset = [[0, Vectors.dense([0, 0, 1])], + ... [0, Vectors.dense([1, 0, 1])], + ... [1, Vectors.dense([2, 1, 1])], + ... [1, Vectors.dense([3, 1, 1])]] + >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) + >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') + >>> chiSqResult.select("degreesOfFreedom").collect()[0] + Row(degreesOfFreedom=[3, 1, 0]) """ sc = SparkContext._active_spark_context javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest @@ -83,40 +87,6 @@ class Correlation(object): which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'` to avoid recomputing the common lineage. - :param dataset: - A dataset or a dataframe. - :param column: - The name of the column of vectors for which the correlation coefficient needs - to be computed. This must be a column of the dataset, and it must contain - Vector objects. - :param method: - String specifying the method to use for computing correlation. - Supported: `pearson` (default), `spearman`. - :return: - A dataframe that contains the correlation matrix of the column of vectors. This - dataframe contains a single row and a single column of name - '$METHODNAME($COLUMN)'. - - >>> from pyspark.ml.linalg import Vectors - >>> from pyspark.ml.stat import Correlation - >>> dataset = [[Vectors.dense([1, 0, 0, -2])], - ... [Vectors.dense([4, 5, 0, 3])], - ... [Vectors.dense([6, 7, 0, 8])], - ... [Vectors.dense([9, 0, 0, 1])]] - >>> dataset = spark.createDataFrame(dataset, ['features']) - >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] - >>> print(str(pearsonCorr).replace('nan', 'NaN')) - DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], - [ 0.0556..., 1. , NaN, 0.9135...], - [ NaN, NaN, 1. , NaN], - [ 0.4004..., 0.9135..., NaN, 1. ]]) - >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] - >>> print(str(spearmanCorr).replace('nan', 'NaN')) - DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], - [ 0.1054..., 1. , NaN, 0.9486... ], - [ NaN, NaN, 1. , NaN], - [ 0.4 , 0.9486... , NaN, 1. ]]) - .. 
versionadded:: 2.2.0 """ @@ -125,6 +95,40 @@ class Correlation(object): def corr(dataset, column, method="pearson"): """ Compute the correlation matrix with specified method using dataset. + + :param dataset: + A Dataset or a DataFrame. + :param column: + The name of the column of vectors for which the correlation coefficient needs + to be computed. This must be a column of the dataset, and it must contain + Vector objects. + :param method: + String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman`. + :return: + A DataFrame that contains the correlation matrix of the column of vectors. This + DataFrame contains a single row and a single column of name + '$METHODNAME($COLUMN)'. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import Correlation + >>> dataset = [[Vectors.dense([1, 0, 0, -2])], + ... [Vectors.dense([4, 5, 0, 3])], + ... [Vectors.dense([6, 7, 0, 8])], + ... [Vectors.dense([9, 0, 0, 1])]] + >>> dataset = spark.createDataFrame(dataset, ['features']) + >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] + >>> print(str(pearsonCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], + [ 0.0556..., 1. , NaN, 0.9135...], + [ NaN, NaN, 1. , NaN], + [ 0.4004..., 0.9135..., NaN, 1. ]]) + >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] + >>> print(str(spearmanCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], + [ 0.1054..., 1. , NaN, 0.9486... ], + [ NaN, NaN, 1. , NaN], + [ 0.4 , 0.9486... , NaN, 1. ]]) """ sc = SparkContext._active_spark_context javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation @@ -132,6 +136,256 @@ def corr(dataset, column, method="pearson"): return _java2py(sc, javaCorrObj.corr(*args)) +class KolmogorovSmirnovTest(object): + """ + .. note:: Experimental + + Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a continuous + distribution. + + By comparing the largest difference between the empirical cumulative + distribution of the sample data and the theoretical distribution we can provide a test for the + the null hypothesis that the sample data comes from that theoretical distribution. + + .. versionadded:: 2.4.0 + + """ + @staticmethod + @since("2.4.0") + def test(dataset, sampleCol, distName, *params): + """ + Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution + equality. Currently supports the normal distribution, taking as parameters the mean and + standard deviation. + + :param dataset: + a Dataset or a DataFrame containing the sample of data to test. + :param sampleCol: + Name of sample column in dataset, of any numerical type. + :param distName: + a `string` name for a theoretical distribution, currently only support "norm". + :param params: + a list of `Double` values specifying the parameters to be used for the theoretical + distribution. For "norm" distribution, the parameters includes mean and variance. + :return: + A DataFrame that contains the Kolmogorov-Smirnov test result for the input sampled data. 
+ This DataFrame will contain a single Row with the following fields: + - `pValue: Double` + - `statistic: Double` + + >>> from pyspark.ml.stat import KolmogorovSmirnovTest + >>> dataset = [[-1.0], [0.0], [1.0]] + >>> dataset = spark.createDataFrame(dataset, ['sample']) + >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 0.0, 1.0).first() + >>> round(ksResult.pValue, 3) + 1.0 + >>> round(ksResult.statistic, 3) + 0.175 + >>> dataset = [[2.0], [3.0], [4.0]] + >>> dataset = spark.createDataFrame(dataset, ['sample']) + >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 3.0, 1.0).first() + >>> round(ksResult.pValue, 3) + 1.0 + >>> round(ksResult.statistic, 3) + 0.175 + """ + sc = SparkContext._active_spark_context + javaTestObj = _jvm().org.apache.spark.ml.stat.KolmogorovSmirnovTest + dataset = _py2java(sc, dataset) + params = [float(param) for param in params] + return _java2py(sc, javaTestObj.test(dataset, sampleCol, distName, + _jvm().PythonUtils.toSeq(params))) + + +class Summarizer(object): + """ + .. note:: Experimental + + Tools for vectorized statistics on MLlib Vectors. + The methods in this package provide various statistics for Vectors contained inside DataFrames. + This class lets users pick the statistics they would like to extract for a given column. + + >>> from pyspark.ml.stat import Summarizer + >>> from pyspark.sql import Row + >>> from pyspark.ml.linalg import Vectors + >>> summarizer = Summarizer.metrics("mean", "count") + >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), + ... Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() + >>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False) + +-----------------------------------+ + |aggregate_metrics(features, weight)| + +-----------------------------------+ + |[[1.0,1.0,1.0], 1] | + +-----------------------------------+ + + >>> df.select(summarizer.summary(df.features)).show(truncate=False) + +--------------------------------+ + |aggregate_metrics(features, 1.0)| + +--------------------------------+ + |[[1.0,1.5,2.0], 2] | + +--------------------------------+ + + >>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False) + +--------------+ + |mean(features)| + +--------------+ + |[1.0,1.0,1.0] | + +--------------+ + + >>> df.select(Summarizer.mean(df.features)).show(truncate=False) + +--------------+ + |mean(features)| + +--------------+ + |[1.0,1.5,2.0] | + +--------------+ + + + .. 
versionadded:: 2.4.0 + + """ + @staticmethod + @since("2.4.0") + def mean(col, weightCol=None): + """ + return a column of mean summary + """ + return Summarizer._get_single_metric(col, weightCol, "mean") + + @staticmethod + @since("2.4.0") + def variance(col, weightCol=None): + """ + return a column of variance summary + """ + return Summarizer._get_single_metric(col, weightCol, "variance") + + @staticmethod + @since("2.4.0") + def count(col, weightCol=None): + """ + return a column of count summary + """ + return Summarizer._get_single_metric(col, weightCol, "count") + + @staticmethod + @since("2.4.0") + def numNonZeros(col, weightCol=None): + """ + return a column of numNonZero summary + """ + return Summarizer._get_single_metric(col, weightCol, "numNonZeros") + + @staticmethod + @since("2.4.0") + def max(col, weightCol=None): + """ + return a column of max summary + """ + return Summarizer._get_single_metric(col, weightCol, "max") + + @staticmethod + @since("2.4.0") + def min(col, weightCol=None): + """ + return a column of min summary + """ + return Summarizer._get_single_metric(col, weightCol, "min") + + @staticmethod + @since("2.4.0") + def normL1(col, weightCol=None): + """ + return a column of normL1 summary + """ + return Summarizer._get_single_metric(col, weightCol, "normL1") + + @staticmethod + @since("2.4.0") + def normL2(col, weightCol=None): + """ + return a column of normL2 summary + """ + return Summarizer._get_single_metric(col, weightCol, "normL2") + + @staticmethod + def _check_param(featuresCol, weightCol): + if weightCol is None: + weightCol = lit(1.0) + if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column): + raise TypeError("featureCol and weightCol should be a Column") + return featuresCol, weightCol + + @staticmethod + def _get_single_metric(col, weightCol, metric): + col, weightCol = Summarizer._check_param(col, weightCol) + return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric, + col._jc, weightCol._jc)) + + @staticmethod + @since("2.4.0") + def metrics(*metrics): + """ + Given a list of metrics, provides a builder that it turns computes metrics from a column. + + See the documentation of [[Summarizer]] for an example. + + The following metrics are accepted (case sensitive): + - mean: a vector that contains the coefficient-wise mean. + - variance: a vector tha contains the coefficient-wise variance. + - count: the count of all vectors seen. + - numNonzeros: a vector with the number of non-zeros for each coefficients + - max: the maximum for each coefficient. + - min: the minimum for each coefficient. + - normL2: the Euclidian norm for each coefficient. + - normL1: the L1 norm of each coefficient (sum of the absolute values). + + :param metrics: + metrics that can be provided. + :return: + an object of :py:class:`pyspark.ml.stat.SummaryBuilder` + + Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD + interface. + """ + sc = SparkContext._active_spark_context + js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics", + _to_seq(sc, metrics)) + return SummaryBuilder(js) + + +class SummaryBuilder(JavaWrapper): + """ + .. note:: Experimental + + A builder object that provides summary statistics about a given column. + + Users should not directly create such builders, but instead use one of the methods in + :py:class:`pyspark.ml.stat.Summarizer` + + .. 
versionadded:: 2.4.0 + + """ + def __init__(self, jSummaryBuilder): + super(SummaryBuilder, self).__init__(jSummaryBuilder) + + @since("2.4.0") + def summary(self, featuresCol, weightCol=None): + """ + Returns an aggregate object that contains the summary of the column with the requested + metrics. + + :param featuresCol: + a column that contains features Vector object. + :param weightCol: + a column that contains weight value. Default weight is 1.0. + :return: + an aggregate column that contains the statistics. The exact content of this + structure is determined during the creation of the builder. + """ + featuresCol, weightCol = Summarizer._check_param(featuresCol, weightCol) + return Column(self._java_obj.summary(featuresCol._jc, weightCol._jc)) + + if __name__ == "__main__": import doctest import pyspark.ml.stat @@ -151,4 +405,4 @@ def corr(dataset, column, method="pearson"): failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1af2b91da900..093593132e56 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -51,7 +51,7 @@ from pyspark.ml.classification import * from pyspark.ml.clustering import * from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.evaluation import BinaryClassificationEvaluator, \ +from pyspark.ml.evaluation import BinaryClassificationEvaluator, ClusteringEvaluator, \ MulticlassClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.fpm import FPGrowth, FPGrowthModel @@ -173,6 +173,45 @@ class MockModel(MockTransformer, Model, HasFake): pass +class JavaWrapperMemoryTests(SparkSessionTestCase): + + def test_java_object_gets_detached(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + + model = lr.fit(df) + summary = model.summary + + self.assertIsInstance(model, JavaWrapper) + self.assertIsInstance(summary, JavaWrapper) + self.assertIsInstance(model, JavaParams) + self.assertNotIsInstance(summary, JavaParams) + + error_no_object = 'Target Object ID does not exist for this gateway' + + self.assertIn("LinearRegression_", model._java_obj.toString()) + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + model.__del__() + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + try: + summary.__del__() + except: + pass + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + summary._java_obj.toString() + + class ParamTypeConversionTests(PySparkTestCase): """ Test that param type conversion happens. 
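Editor's note: the Summarizer doctests above show the expected console output; as a quick orientation, the short sketch below (not part of the patch) drives the same builder API end to end. It assumes an active SparkSession bound to `spark`; the column names are illustrative.

from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import Summarizer
from pyspark.sql import Row

df = spark.createDataFrame([
    Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
    Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0)),
])

# Build the aggregate once and reuse it; summary() yields a single struct
# column with one field per requested metric.
summarizer = Summarizer.metrics("mean", "variance", "count")
stats = df.select(summarizer.summary(df.features, df.weight).alias("stats")).first()
print(stats.stats)

# Single-metric helpers return a plain Column and default the weight to 1.0.
df.select(Summarizer.mean(df.features)).show(truncate=False)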
@@ -330,7 +369,7 @@ def test_property(self): raise RuntimeError("Test property to raise error when invoked") -class ParamTests(PySparkTestCase): +class ParamTests(SparkSessionTestCase): def test_copy_new_parent(self): testParams = TestParams() @@ -418,6 +457,9 @@ def test_kmeans_param(self): self.assertEqual(algo.getK(), 10) algo.setInitSteps(10) self.assertEqual(algo.getInitSteps(), 10) + self.assertEqual(algo.getDistanceMeasure(), "euclidean") + algo.setDistanceMeasure("cosine") + self.assertEqual(algo.getDistanceMeasure(), "cosine") def test_hasseed(self): noSeedSpecd = TestParams() @@ -472,6 +514,24 @@ def test_logistic_regression_check_thresholds(self): LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] ) + def test_preserve_set_state(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + self.assertFalse(binarizer.isSet("threshold")) + binarizer.transform(dataset) + binarizer._transfer_params_from_java() + self.assertFalse(binarizer.isSet("threshold"), + "Params not explicitly set should remain unset after transform") + + def test_default_params_transferred(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + # intentionally change the pyspark default, but don't set it + binarizer._defaultParamMap[binarizer.outputCol] = "my_default" + result = binarizer.transform(dataset).select("my_default").collect() + self.assertFalse(binarizer.isSet(binarizer.outputCol)) + self.assertEqual(result[0][0], 1.0) + @staticmethod def check_params(test_self, py_stage, check_params_exist=True): """ @@ -538,6 +598,15 @@ def test_java_params(self): self.assertEqual(evaluator._java_obj.getMetricName(), "r2") self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") + def test_clustering_evaluator_with_cosine_distance(self): + featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), + [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), + ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) + dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) + evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") + self.assertEqual(evaluator.getDistanceMeasure(), "cosine") + self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) + class FeatureTests(SparkSessionTestCase): @@ -628,6 +697,59 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_count_vectorizer_with_maxDF(self): + dataset = self.spark.createDataFrame([ + (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), + (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + cv = CountVectorizer(inputCol="words", outputCol="features") + model1 = cv.setMaxDF(3).fit(dataset) + self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) + + transformedList1 = model1.transform(dataset).select("features", "expected").collect() + + for r in transformedList1: + feature, expected = r + self.assertEqual(feature, expected) + + model2 = cv.setMaxDF(0.75).fit(dataset) + self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) + + transformedList2 = model2.transform(dataset).select("features", "expected").collect() + + for r in transformedList2: + feature, expected = r + self.assertEqual(feature, expected) + + def 
test_count_vectorizer_from_vocab(self): + model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", + outputCol="features", minTF=2) + self.assertEqual(model.vocabulary, ["a", "b", "c"]) + self.assertEqual(model.getMinTF(), 2) + + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), + (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + + transformed_list = model.transform(dataset).select("features", "expected").collect() + + for r in transformed_list: + feature, expected = r + self.assertEqual(feature, expected) + + # Test an empty vocabulary + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): + CountVectorizerModel.from_vocabulary([], inputCol="words") + + # Test model with default settings can transform + model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") + transformed_list = model_default.transform(dataset)\ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 3) + def test_rformula_force_index_label(self): df = self.spark.createDataFrame([ (1.0, 1.0, "a"), @@ -678,6 +800,43 @@ def test_string_indexer_handle_invalid(self): expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] self.assertEqual(actual2, expected2) + def test_string_indexer_from_labels(self): + model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", + outputCol="indexed", handleInvalid="keep") + self.assertEqual(model.labels, ["a", "b", "c"]) + + df1 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, None), + (3, "b"), + (4, "b")], ["id", "label"]) + + result1 = model.transform(df1) + actual1 = result1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), + Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] + self.assertEqual(actual1, expected1) + + model_empty_labels = StringIndexerModel.from_labels( + [], inputCol="label", outputCol="indexed", handleInvalid="keep") + actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), + Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] + self.assertEqual(actual2, expected2) + + # Test model with default settings can transform + model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") + df2 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, "b"), + (3, "b"), + (4, "b")], ["id", "label"]) + transformed_list = model_default.transform(df2)\ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 5) + class HasInducedError(Params): @@ -859,6 +1018,50 @@ def test_parallel_evaluation(self): cvParallelModel = cv.fit(dataset) self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + numFolds = 3 + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, 
evaluator=evaluator, + numFolds=numFolds, collectSubModels=True) + + def checkSubModels(subModels): + self.assertEqual(len(subModels), numFolds) + for i in range(numFolds): + self.assertEqual(len(subModels[i]), len(grid)) + + cvModel = cv.fit(dataset) + checkSubModels(cvModel.subModels) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testCrossValidatorSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + cvModel.save(savingPathWithSubModels) + cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + checkSubModels(cvModel3.subModels) + cvModel4 = cvModel3.copy() + checkSubModels(cvModel4.subModels) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + self.assertEqual(cvModel2.subModels, None) + + for i in range(numFolds): + for j in range(len(grid)): + self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) + def test_save_load_nested_estimator(self): temp_path = tempfile.mkdtemp() dataset = self.spark.createDataFrame( @@ -1027,6 +1230,40 @@ def test_parallel_evaluation(self): tvsParallelModel = tvs.fit(dataset) self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + collectSubModels=True) + tvsModel = tvs.fit(dataset) + self.assertEqual(len(tvsModel.subModels), len(grid)) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testTrainValidationSplitSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + tvsModel.save(savingPathWithSubModels) + tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + self.assertEqual(len(tvsModel3.subModels), len(grid)) + tvsModel4 = tvsModel3.copy() + self.assertEqual(len(tvsModel4.subModels), len(grid)) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + self.assertEqual(tvsModel2.subModels, None) + + for i in range(len(grid)): + self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) + def test_save_load_nested_estimator(self): # This tests saving and loading the trained model only. 
# Save/load for TrainValidationSplit will be added later: SPARK-13786 @@ -1437,6 +1674,7 @@ def test_linear_regression_summary(self): self.assertAlmostEqual(s.meanSquaredError, 0.0) self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) self.assertAlmostEqual(s.r2, 1.0, 2) + self.assertAlmostEqual(s.r2adj, 1.0, 2) self.assertTrue(isinstance(s.residuals, DataFrame)) self.assertEqual(s.numInstances, 2) self.assertEqual(s.degreesOfFreedom, 1) @@ -1620,6 +1858,21 @@ def test_kmeans_summary(self): self.assertEqual(s.k, 2) +class KMeansTests(SparkSessionTestCase): + + def test_kmeans_cosine_distance(self): + data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), + (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), + (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] + df = self.spark.createDataFrame(data, ["features"]) + kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") + model = kmeans.fit(df) + result = model.transform(df).collect() + self.assertTrue(result[0].prediction == result[1].prediction) + self.assertTrue(result[2].prediction == result[3].prediction) + self.assertTrue(result[4].prediction == result[5].prediction) + + class OneVsRestTests(SparkSessionTestCase): def test_copy(self): @@ -1852,6 +2105,7 @@ def test_read_images(self): self.assertEqual(len(array), first_row[1]) self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) self.assertEqual(df.schema, ImageSchema.imageSchema) + self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema) expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} self.assertEqual(ImageSchema.ocvTypes, expected) expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] @@ -1882,17 +2136,23 @@ class ImageReaderTest2(PySparkTestCase): @classmethod def setUpClass(cls): super(ImageReaderTest2, cls).setUpClass() + cls.hive_available = True # Note that here we enable Hive's support. 
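# Editor's note (not part of the patch): setUpClass now records whether Hive classes are
# available instead of raising SkipTest immediately; the setUp() added just below skips
# each individual test when hive_available is False.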
cls.spark = None try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") - cls.spark = HiveContext._createForTesting(cls.sc) + cls.hive_available = False + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -1942,18 +2202,28 @@ def test_java_params(self): import pyspark.ml.feature import pyspark.ml.classification import pyspark.ml.clustering + import pyspark.ml.evaluation import pyspark.ml.pipeline import pyspark.ml.recommendation import pyspark.ml.regression + modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, - pyspark.ml.pipeline, pyspark.ml.recommendation, pyspark.ml.regression] + pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, + pyspark.ml.regression] for module in modules: for name, cls in inspect.getmembers(module, inspect.isclass): - if not name.endswith('Model') and issubclass(cls, JavaParams)\ - and not inspect.isabstract(cls): + if not name.endswith('Model') and not name.endswith('Params')\ + and issubclass(cls, JavaParams) and not inspect.isabstract(cls): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) + # Additional classes that need explicit construction + from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel + ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), + check_params_exist=False) + ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), + check_params_exist=False) + def _squared_distance(a, b): if isinstance(a, Vector): @@ -2398,6 +2668,6 @@ def testDefaultFitMultiple(self): if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 6c0cad6cbaaa..0c8029f293cf 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -15,14 +15,16 @@ # limitations under the License. # import itertools -import numpy as np +import sys from multiprocessing.pool import ThreadPool +import numpy as np + from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java from pyspark.ml.param import Params, Param, TypeConverters -from pyspark.ml.param.shared import HasParallelism, HasSeed +from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand @@ -31,7 +33,7 @@ 'TrainValidationSplitModel'] -def _parallelFitTasks(est, train, eva, validation, epm): +def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel): """ Creates a list of callables which can be called from different threads to fit and evaluate an estimator in parallel. Each callable returns an `(index, metric)` pair. 
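Editor's note: the tuning.py hunks that follow thread the new collectSubModels option through _parallelFitTasks, CrossValidator and TrainValidationSplit. As a user-facing orientation, here is a minimal sketch (not part of the patch); it assumes an active SparkSession bound to `spark`, and the toy data set and save path are illustrative.

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

df = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0), (Vectors.dense([1.0]), 1.0)] * 20,
    ["features", "label"])
lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                    evaluator=BinaryClassificationEvaluator(),
                    numFolds=3, collectSubModels=True)
cvModel = cv.fit(df)

# subModels is indexed [fold][index into estimatorParamMaps]; it stays None
# unless collectSubModels=True was set before fitting.
assert len(cvModel.subModels) == 3
assert len(cvModel.subModels[0]) == len(grid)

# Sub-models are persisted by default; they can be dropped at save time
# through the new writer option.
cvModel.write().option("persistSubModels", "false").save("/tmp/cv_model_without_submodels")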
@@ -41,14 +43,15 @@ def _parallelFitTasks(est, train, eva, validation, epm): :param eva: Evaluator, used to compute `metric` :param validation: DataFrame, validation data set, used for evaluation. :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation. - :return: (int, float), an index into `epm` and the associated metric value. + :param collectSubModel: Whether to collect sub model. + :return: (int, float, subModel), an index into `epm` and the associated metric value. """ modelIter = est.fitMultiple(train, epm) def singleTask(): index, model = next(modelIter) metric = eva.evaluate(model.transform(validation, epm[index])) - return index, metric + return index, metric, model if collectSubModel else None return [singleTask] * len(epm) @@ -192,7 +195,8 @@ def _to_java_impl(self): return java_estimator, java_epms, java_evaluator -class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ K-fold cross validation performs model selection by splitting the dataset into a set of @@ -231,10 +235,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1) + seed=None, parallelism=1, collectSubModels=False) """ super(CrossValidator, self).__init__() self._setDefault(numFolds=3, parallelism=1) @@ -244,10 +248,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF @keyword_only @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): Sets params for cross validator. 
""" kwargs = self._input_kwargs @@ -280,6 +284,10 @@ def _fit(self, dataset): metrics = [0.0] * numModels pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [[None for j in range(numModels)] for i in range(nFolds)] for i in range(nFolds): validateLB = i * h @@ -288,9 +296,12 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] += (metric / nFolds) + if collectSubModelsParam: + subModels[i][j] = subModel + validation.unpersist() train.unpersist() @@ -299,7 +310,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(CrossValidatorModel(bestModel, metrics)) + return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels)) @since("1.4.0") def copy(self, extra=None): @@ -343,9 +354,11 @@ def _from_java(cls, java_stage): numFolds = java_stage.getNumFolds() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - numFolds=numFolds, seed=seed, parallelism=parallelism) + numFolds=numFolds, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -365,6 +378,7 @@ def _to_java(self): _java_obj.setSeed(self.getSeed()) _java_obj.setNumFolds(self.getNumFolds()) _java_obj.setParallelism(self.getParallelism()) + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -379,13 +393,15 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 1.4.0 """ - def __init__(self, bestModel, avgMetrics=[]): + def __init__(self, bestModel, avgMetrics=[], subModels=None): super(CrossValidatorModel, self).__init__() #: best model from cross validation self.bestModel = bestModel #: Average cross-validation metrics for each paramMap in #: CrossValidator.estimatorParamMaps, in the corresponding order. self.avgMetrics = avgMetrics + #: sub model list from cross validation + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -397,6 +413,7 @@ def copy(self, extra=None): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + It does not copy the extra Params into the subModels. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -405,7 +422,8 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) avgMetrics = self.avgMetrics - return CrossValidatorModel(bestModel, avgMetrics) + subModels = self.subModels + return CrossValidatorModel(bestModel, avgMetrics, subModels) @since("2.3.0") def write(self): @@ -424,13 +442,17 @@ def _from_java(cls, java_stage): Given a Java CrossValidatorModel, create and return a Python wrapper of it. Used for ML persistence. 
""" - bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [[JavaParams._from_java(sub_model) + for sub_model in fold_sub_models] + for fold_sub_models in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -452,10 +474,16 @@ def _to_java(self): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models] + for fold_sub_models in self.subModels] + _java_obj.setSubModels(java_sub_models) return _java_obj -class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ .. note:: Experimental @@ -490,10 +518,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None) + parallelism=1, collectSubModels=False, seed=None) """ super(TrainValidationSplit, self).__init__() self._setDefault(trainRatio=0.75, parallelism=1) @@ -503,10 +531,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trai @since("2.0.0") @keyword_only def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): Sets params for the train validation split. 
""" kwargs = self._input_kwargs @@ -539,11 +567,19 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [None for i in range(numModels)] + + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) pool = ThreadPool(processes=min(self.getParallelism(), numModels)) metrics = [None] * numModels - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] = metric + if collectSubModelsParam: + subModels[j] = subModel + train.unpersist() validation.unpersist() @@ -552,7 +588,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(TrainValidationSplitModel(bestModel, metrics)) + return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels)) @since("2.0.0") def copy(self, extra=None): @@ -596,9 +632,11 @@ def _from_java(cls, java_stage): trainRatio = java_stage.getTrainRatio() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - trainRatio=trainRatio, seed=seed, parallelism=parallelism) + trainRatio=trainRatio, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -618,7 +656,7 @@ def _to_java(self): _java_obj.setTrainRatio(self.getTrainRatio()) _java_obj.setSeed(self.getSeed()) _java_obj.setParallelism(self.getParallelism()) - + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -631,12 +669,14 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 2.0.0 """ - def __init__(self, bestModel, validationMetrics=[]): + def __init__(self, bestModel, validationMetrics=[], subModels=None): super(TrainValidationSplitModel, self).__init__() - #: best model from cross validation + #: best model from train validation split self.bestModel = bestModel #: evaluated validation metrics self.validationMetrics = validationMetrics + #: sub models from train validation split + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -649,6 +689,7 @@ def copy(self, extra=None): creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. And, this creates a shallow copy of the validationMetrics. + It does not copy the extra Params into the subModels. 
:param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -657,7 +698,8 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) validationMetrics = list(self.validationMetrics) - return TrainValidationSplitModel(bestModel, validationMetrics) + subModels = self.subModels + return TrainValidationSplitModel(bestModel, validationMetrics, subModels) @since("2.3.0") def write(self): @@ -685,6 +727,10 @@ def _from_java(cls, java_stage): py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [JavaParams._from_java(sub_model) + for sub_model in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -706,6 +752,11 @@ def _to_java(self): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [sub_model._to_java() for sub_model in self.subModels] + _java_obj.setSubModels(java_sub_models) + return _java_obj @@ -727,4 +778,4 @@ def _to_java(self): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index c3c47bd79459..a486c6a3fdeb 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -169,6 +169,10 @@ def overwrite(self): self._jwrite.overwrite() return self + def option(self, key, value): + self._jwrite.option(key, value) + return self + def context(self, sqlContext): """ Sets the SQL context to use for saving. diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 0f846fbc5b5e..d325633195dd 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -36,6 +36,10 @@ def __init__(self, java_obj=None): super(JavaWrapper, self).__init__() self._java_obj = java_obj + def __del__(self): + if SparkContext._active_spark_context and self._java_obj is not None: + SparkContext._active_spark_context._gateway.detach(self._java_obj) + @classmethod def _create_from_java_class(cls, java_class, *args): """ @@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params): __metaclass__ = ABCMeta - def __del__(self): - if SparkContext._active_spark_context: - SparkContext._active_spark_context._gateway.detach(self._java_obj) - def _make_java_param_pair(self, param, value): """ Makes a Java param pair. @@ -118,11 +118,18 @@ def _transfer_params_to_java(self): """ Transforms the embedded params to the companion Java object. 
""" - paramMap = self.extractParamMap() + pair_defaults = [] for param in self.params: - if param in paramMap: - pair = self._make_java_param_pair(param, paramMap[param]) + if self.isSet(param): + pair = self._make_java_param_pair(param, self._paramMap[param]) self._java_obj.set(pair) + if self.hasDefault(param): + pair = self._make_java_param_pair(param, self._defaultParamMap[param]) + pair_defaults.append(pair) + if len(pair_defaults) > 0: + sc = SparkContext._active_spark_context + pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults) + self._java_obj.setDefault(pair_defaults_seq) def _transfer_param_map_to_java(self, pyParamMap): """ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index cce703d432b5..bb281981fd56 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -16,6 +16,7 @@ # from math import exp +import sys import warnings import numpy @@ -761,7 +762,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index bb687a7da6ff..0cbabab13a89 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -1048,7 +1048,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 2cd1da3fbf9a..36cb03369b8c 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from pyspark import since @@ -542,7 +543,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index e5231dc3a27a..40ecd2e0ff4b 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -819,7 +819,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": sys.path.pop(0) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index f58ea5dfb087..de18dad1f675 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -15,6 +15,8 @@ # limitations under the License. 
# +import sys + import numpy from numpy import array from collections import namedtuple @@ -197,7 +199,7 @@ def _test(): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 7b24b3c74a9f..60d96d8d5ceb 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1370,7 +1370,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 4cb802514be5..bba88542167a 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -1377,7 +1377,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 61213ddf62e8..a8833cb44692 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -19,6 +19,7 @@ Python package for random data generation. """ +import sys from functools import wraps from pyspark import since @@ -421,7 +422,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 81182881352b..3d4eae85132b 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -16,6 +16,7 @@ # import array +import sys from collections import namedtuple from pyspark import SparkContext, since @@ -326,7 +327,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index ea107d400621..6be45f51862c 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -15,9 +15,11 @@ # limitations under the License. 
# +import sys +import warnings + import numpy as np from numpy import array -import warnings from pyspark import RDD, since from pyspark.streaming.dstream import DStream @@ -837,7 +839,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 49b26446dbc3..3c75b132ecad 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -313,7 +313,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1037bab7f108..4c2ce137e331 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -57,6 +57,7 @@ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs @@ -1762,14 +1763,25 @@ def test_pca(self): self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) +class FPGrowthTest(MLlibTestCase): + + def test_fpgrowth(self): + data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + rdd = self.sc.parallelize(data, 2) + model1 = FPGrowth.train(rdd, 0.6, 2) + # use default data partition number when numPartitions is not specified + model2 = FPGrowth.train(rdd, 0.6) + self.assertEqual(sorted(model1.freqItemsets().collect()), + sorted(model2.freqItemsets().collect())) + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") sc.stop() diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 619fa16d463f..b05734ce489d 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -17,6 +17,7 @@ from __future__ import absolute_import +import sys import random from pyspark import SparkContext, RDD, since @@ -654,7 +655,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 97755807ef26..fc7809387b13 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -521,7 +521,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 44d17bd62947..3c7656ab5758 100644 --- a/python/pyspark/profiler.py +++ 
b/python/pyspark/profiler.py @@ -19,6 +19,7 @@ import pstats import os import atexit +import sys from pyspark.accumulators import AccumulatorParam @@ -173,4 +174,4 @@ def stats(self): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 340bc3a6b747..ac127ac5d61c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -51,6 +51,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import fail_on_stopiteration __all__ = ["RDD"] @@ -68,8 +69,9 @@ class PythonEvalType(object): SQL_BATCHED_UDF = 100 - SQL_PANDAS_SCALAR_UDF = 200 - SQL_PANDAS_GROUP_MAP_UDF = 201 + SQL_SCALAR_PANDAS_UDF = 200 + SQL_GROUPED_MAP_PANDAS_UDF = 201 + SQL_GROUPED_AGG_PANDAS_UDF = 202 def portable_hash(x): @@ -331,7 +333,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(f, iterator) + return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -346,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(f, iterator)) + return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -409,7 +411,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(f, iterator) + return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -766,7 +768,7 @@ def func(iterator): def pipe_objs(out): for obj in iterator: - s = str(obj).rstrip('\n') + '\n' + s = unicode(obj).rstrip('\n') + '\n' out.write(s.encode('utf-8')) out.close() Thread(target=pipe_objs, args=[pipe.stdin]).start() @@ -790,6 +792,8 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ + f = fail_on_stopiteration(f) + def processPartition(iterator): for x in iterator: f(x) @@ -839,6 +843,8 @@ def reduce(self, f): ... 
ValueError: Can not reduce() empty RDD """ + f = fail_on_stopiteration(f) + def func(iterator): iterator = iter(iterator) try: @@ -910,6 +916,8 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ + op = fail_on_stopiteration(op) + def func(iterator): acc = zeroValue for obj in iterator: @@ -942,6 +950,9 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ + seqOp = fail_on_stopiteration(seqOp) + combOp = fail_on_stopiteration(combOp) + def func(iterator): acc = zeroValue for obj in iterator: @@ -1635,6 +1646,8 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ + func = fail_on_stopiteration(func) + def reducePartition(iterator): m = {} for k, v in iterator: @@ -2497,7 +2510,7 @@ def _test(): globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 88d6a191babc..15753f77bd90 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -68,6 +68,7 @@ xrange = range from pyspark import cloudpickle +from pyspark.util import _exception_message __all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] @@ -230,6 +231,10 @@ def create_array(s, t): s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) + elif t is not None and pa.types.is_string(t) and sys.version < '3': + # TODO: need decode before converting to Arrow in Python 2 + return pa.Array.from_pandas(s.apply( + lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) return pa.Array.from_pandas(s, mask=mask, type=t) arrs = [create_array(s, t) for s, t in series] @@ -245,6 +250,15 @@ def __init__(self, timezone): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone + def arrow_to_pandas(self, arrow_column): + from pyspark.sql.types import from_arrow_type, \ + _check_series_convert_date, _check_series_localize_timestamps + + s = arrow_column.to_pandas() + s = _check_series_convert_date(s, from_arrow_type(arrow_column.type)) + s = _check_series_localize_timestamps(s, self._timezone) + return s + def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or @@ -267,13 +281,11 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. 
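The hunks above repeatedly wrap user-supplied functions (for map, filter, reduce, fold, aggregate and the shuffle aggregator) in `fail_on_stopiteration` from `pyspark.util`. As a rough sketch of what such a wrapper can look like — the actual helper in `pyspark.util` may differ in wording — the idea is simply to convert a `StopIteration` escaping from user code into a `RuntimeError`, so the task fails loudly instead of silently ending the surrounding iteration early:

    def fail_on_stopiteration(f):
        # Sketch only: re-raise StopIteration from user code as RuntimeError so that
        # Spark's own iteration over the partition is not cut short unnoticed.
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except StopIteration as exc:
                raise RuntimeError(
                    "Caught StopIteration thrown from user's code; failing the task", exc)
        return wrapper
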
""" - from pyspark.sql.types import _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) + for batch in reader: - # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1 - pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone) - yield [c for _, c in pdf.iteritems()] + yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] def __repr__(self): return "ArrowStreamPandasSerializer" @@ -558,7 +570,18 @@ def loads(self, obj, encoding=None): class CloudPickleSerializer(PickleSerializer): def dumps(self, obj): - return cloudpickle.dumps(obj, 2) + try: + return cloudpickle.dumps(obj, 2) + except pickle.PickleError: + raise + except Exception as e: + emsg = _exception_message(e) + if "'i' format requires" in emsg: + msg = "Object too large to serialize: %s" % emsg + else: + msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) + cloudpickle.print_exec(sys.stderr) + raise pickle.PicklingError(msg) class MarshalSerializer(FramedSerializer): @@ -692,4 +715,4 @@ def write_with_length(obj, stream): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index e974cda9fc3e..bd0ac0039ffe 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -23,10 +23,12 @@ import itertools import operator import random +import sys import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer +from pyspark.util import fail_on_stopiteration try: @@ -93,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = createCombiner - self.mergeValue = mergeValue - self.mergeCombiners = mergeCombiners + self.createCombiner = fail_on_stopiteration(createCombiner) + self.mergeValue = fail_on_stopiteration(mergeValue) + self.mergeCombiners = fail_on_stopiteration(mergeCombiners) class SimpleAggregator(Aggregator): @@ -810,4 +812,4 @@ def load_partition(j): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 659bc65701a0..b0d8357f4fee 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from collections import namedtuple @@ -224,42 +225,17 @@ def dropGlobalTempView(self, viewName): """ self._jcatalog.dropGlobalTempView(viewName) - @ignore_unicode_prefix @since(2.0) - def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. 
- - :param name: name of the UDF - :param f: python function - :param returnType: a :class:`pyspark.sql.types.DataType` object - :return: a wrapped :class:`UserDefinedFunction` - - >>> strlen = spark.catalog.registerFunction("stringLengthString", len) - >>> spark.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] + def registerFunction(self, name, f, returnType=None): + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. + + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. """ - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - self._jsparkSession.udf().registerPython(name, udf._judf) - return udf._wrapped() + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.register instead.", + DeprecationWarning) + return self._sparkSession.udf.register(name, f, returnType) @since(2.0) def isCached(self, tableName): @@ -331,7 +307,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 43b38a2cd477..e7dec11c69b5 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -16,7 +16,6 @@ # import sys -import warnings import json if sys.version >= '3': @@ -448,24 +447,72 @@ def isin(self, *cols): # order _asc_doc = """ - Returns a sort expression based on the ascending order of the given column name + Returns a sort expression based on ascending order of the column. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.asc()).collect() [Row(name=u'Alice'), Row(name=u'Tom')] """ + _asc_nulls_first_doc = """ + Returns a sort expression based on ascending order of the column, and null values + return before non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect() + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] + + .. versionadded:: 2.4 + """ + _asc_nulls_last_doc = """ + Returns a sort expression based on ascending order of the column, and null values + appear after non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect() + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] + + .. versionadded:: 2.4 + """ _desc_doc = """ - Returns a sort expression based on the descending order of the given column name. 
+ Returns a sort expression based on the descending order of the column. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.desc()).collect() [Row(name=u'Tom'), Row(name=u'Alice')] """ + _desc_nulls_first_doc = """ + Returns a sort expression based on the descending order of the column, and null values + appear before non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect() + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] + + .. versionadded:: 2.4 + """ + _desc_nulls_last_doc = """ + Returns a sort expression based on the descending order of the column, and null values + appear after non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect() + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + + .. versionadded:: 2.4 + """ asc = ignore_unicode_prefix(_unary_op("asc", _asc_doc)) + asc_nulls_first = ignore_unicode_prefix(_unary_op("asc_nulls_first", _asc_nulls_first_doc)) + asc_nulls_last = ignore_unicode_prefix(_unary_op("asc_nulls_last", _asc_nulls_last_doc)) desc = ignore_unicode_prefix(_unary_op("desc", _desc_doc)) + desc_nulls_first = ignore_unicode_prefix(_unary_op("desc_nulls_first", _desc_nulls_first_doc)) + desc_nulls_last = ignore_unicode_prefix(_unary_op("desc_nulls_last", _desc_nulls_last_doc)) _isNull_doc = """ True if the current expression is null. @@ -660,7 +707,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index 792c420ca638..db49040e17b6 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -15,7 +15,9 @@ # limitations under the License. # -from pyspark import since +import sys + +from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix @@ -37,15 +39,16 @@ def set(self, key, value): @ignore_unicode_prefix @since(2.0) - def get(self, key, default=None): + def get(self, key, default=_NoValue): """Returns the value of Spark runtime configuration property for the given key, assuming it is set. 
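The switch from `default=None` to `default=_NoValue` in `RuntimeConfig.get` (and, below, in `SQLContext.getConf` and `DataFrame.replace`) relies on a sentinel object so the API can distinguish "no default was passed" from "the caller explicitly passed None". A minimal sketch of the pattern, where `lookup`, `lookup_with_default` and `check_type` are placeholder names standing in for the real config plumbing:

    _NoValue = object()  # sentinel: distinct from None and from any real value

    def get(key, default=_NoValue):
        if default is _NoValue:
            return lookup(key)            # no default supplied: an unset key raises
        if default is not None:
            check_type(default)           # None is now a legitimate, explicit default
        return lookup_with_default(key, default)
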
""" self._checkType(key, "key") - if default is None: + if default is _NoValue: return self._jconf.get(key) else: - self._checkType(default, "default") + if default is not None: + self._checkType(default, "default") return self._jconf.get(key, default) @ignore_unicode_prefix @@ -64,7 +67,6 @@ def _checkType(self, obj, identifier): def _test(): import os import doctest - from pyspark.context import SparkContext from pyspark.sql.session import SparkSession import pyspark.sql.conf @@ -80,7 +82,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(pyspark.sql.conf, globs=globs) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b1e723cdecef..e9ec7ba86676 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -22,16 +22,17 @@ if sys.version >= '3': basestring = unicode = str -from pyspark import since +from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader from pyspark.sql.types import IntegerType, Row, StringType +from pyspark.sql.udf import UDFRegistration from pyspark.sql.utils import install_exception_handler -__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] +__all__ = ["SQLContext", "HiveContext"] class SQLContext(object): @@ -123,11 +124,11 @@ def setConf(self, key, value): @ignore_unicode_prefix @since(1.3) - def getConf(self, key, defaultValue=None): + def getConf(self, key, defaultValue=_NoValue): """Returns the value of Spark SQL configuration property for the given key. - If the key is not set and defaultValue is not None, return - defaultValue. If the key is not set and defaultValue is None, return + If the key is not set and defaultValue is set, return + defaultValue. If the key is not set and defaultValue is not set, return the system default value. >>> sqlContext.getConf("spark.sql.shuffle.partitions") @@ -147,7 +148,7 @@ def udf(self): :return: :class:`UDFRegistration` """ - return UDFRegistration(self) + return self.sparkSession.udf @since(1.4) def range(self, start, end=None, step=1, numPartitions=None): @@ -172,82 +173,29 @@ def range(self, start, end=None, step=1, numPartitions=None): """ return self.sparkSession.range(start, end, step, numPartitions) - @ignore_unicode_prefix @since(1.2) - def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. 
- - :param name: name of the UDF - :param f: python function - :param returnType: a :class:`pyspark.sql.types.DataType` object - :return: a wrapped :class:`UserDefinedFunction` - - >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] + def registerFunction(self, name, f, returnType=None): + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. + + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. """ - return self.sparkSession.catalog.registerFunction(name, f, returnType) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.register instead.", + DeprecationWarning) + return self.sparkSession.udf.register(name, f, returnType) - @ignore_unicode_prefix @since(2.1) def registerJavaFunction(self, name, javaClassName, returnType=None): - """Register a java UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not specified we would infer it via reflection. - :param name: name of the UDF - :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object - - >>> sqlContext.registerJavaFunction("javaStringLength", - ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) - >>> sqlContext.sql("SELECT javaStringLength('test')").collect() - [Row(UDF:javaStringLength(test)=4)] - >>> sqlContext.registerJavaFunction("javaStringLength2", - ... "test.org.apache.spark.sql.JavaStringLength") - >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() - [Row(UDF:javaStringLength2(test)=4)] + """An alias for :func:`spark.udf.registerJavaFunction`. + See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`. + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead. """ - jdt = None - if returnType is not None: - jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) - self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) - - @ignore_unicode_prefix - @since(2.3) - def registerJavaUDAF(self, name, javaClassName): - """Register a java UDAF so it can be used in SQL statements. - - :param name: name of the UDAF - :param javaClassName: fully qualified name of java class - - >>> sqlContext.registerJavaUDAF("javaUDAF", - ... "test.org.apache.spark.sql.MyDoubleAvg") - >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) - >>> df.registerTempTable("df") - >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() - [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] - """ - self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + warnings.warn( + "Deprecated in 2.3.0. 
Use spark.udf.registerJavaFunction instead.", + DeprecationWarning) + return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType) # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): @@ -559,24 +507,6 @@ def refreshTable(self, tableName): self._ssql_ctx.refreshTable(tableName) -class UDFRegistration(object): - """Wrapper for user-defined function registration.""" - - def __init__(self, sqlContext): - self.sqlContext = sqlContext - - def register(self, name, f, returnType=StringType()): - return self.sqlContext.registerFunction(name, f, returnType) - - def registerJavaFunction(self, name, javaClassName, returnType=None): - self.sqlContext.registerJavaFunction(name, javaClassName, returnType) - - def registerJavaUDAF(self, name, javaClassName): - self.sqlContext.registerJavaUDAF(name, javaClassName) - - register.__doc__ = SQLContext.registerFunction.__doc__ - - def _test(): import os import doctest @@ -613,7 +543,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 95eca76fa988..16f8e52dead7 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -27,7 +27,7 @@ import warnings -from pyspark import copy_func, since +from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer @@ -588,6 +588,8 @@ def coalesce(self, numPartitions): """ Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions. + :param numPartitions: int, to specify the target number of partitions + Similar to coalesce defined on an :class:`RDD`, this operation results in a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of the 100 new partitions will @@ -612,9 +614,10 @@ def repartition(self, numPartitions, *cols): Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting DataFrame is hash partitioned. - ``numPartitions`` can be an int to specify the target number of partitions or a Column. - If it is a Column, it will be used as the first partitioning column. If not specified, - the default number of partitions is used. + :param numPartitions: + can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. .. versionchanged:: 1.6 Added optional arguments to specify the partitioning columns. Also made numPartitions @@ -667,6 +670,52 @@ def repartition(self, numPartitions, *cols): else: raise TypeError("numPartitions should be an int or Column") + @since("2.4.0") + def repartitionByRange(self, numPartitions, *cols): + """ + Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The + resulting DataFrame is range partitioned. + + :param numPartitions: + can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. + + At least one partition-by expression must be specified. 
+ When no explicit sort order is specified, "ascending nulls first" is assumed. + + >>> df.repartitionByRange(2, "age").rdd.getNumPartitions() + 2 + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + >>> df.repartitionByRange(1, "age").rdd.getNumPartitions() + 1 + >>> data = df.repartitionByRange("age") + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + """ + if isinstance(numPartitions, int): + if len(cols) == 0: + raise ValueError("At least one partition-by expression must be specified.") + else: + return DataFrame( + self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx) + elif isinstance(numPartitions, (basestring, Column)): + cols = (numPartitions,) + cols + return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx) + else: + raise TypeError("numPartitions should be an int, string or Column") + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. @@ -819,11 +868,36 @@ def columns(self): """ return [f.name for f in self.schema.fields] + @since(2.3) + def colRegex(self, colName): + """ + Selects column based on the column name specified as a regex and returns it + as :class:`Column`. + + :param colName: string, column name specified as a regex. + + >>> df = spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["Col1", "Col2"]) + >>> df.select(df.colRegex("`(Col1)?+.+`")).show() + +----+ + |Col2| + +----+ + | 1| + | 2| + | 3| + +----+ + """ + if not isinstance(colName, basestring): + raise ValueError("colName should be provided as string") + jc = self._jdf.colRegex(colName) + return Column(jc) + @ignore_unicode_prefix @since(1.3) def alias(self, alias): """Returns a new :class:`DataFrame` with an alias set. + :param alias: string, an alias name to be set for the DataFrame. + >>> from pyspark.sql.functions import * >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") @@ -1364,7 +1438,8 @@ def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame but not in another frame. - This is equivalent to `EXCEPT` in SQL. + This is equivalent to `EXCEPT DISTINCT` in SQL. + """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) @@ -1508,7 +1583,7 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @since(1.4) - def replace(self, to_replace, value=None, subset=None): + def replace(self, to_replace, value=_NoValue, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. @@ -1521,8 +1596,8 @@ def replace(self, to_replace, value=None, subset=None): :param to_replace: bool, int, long, float, string, list or dict. Value to be replaced. - If the value is a dict, then `value` is ignored and `to_replace` must be a - mapping between a value and a replacement. + If the value is a dict, then `value` is ignored or can be omitted, and `to_replace` + must be a mapping between a value and a replacement. :param value: bool, int, long, float, string, list or None. The replacement value must be a bool, int, long, float, string or None. If `value` is a list, `value` should be of the same length and type as `to_replace`.
@@ -1553,7 +1628,7 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ - >>> df4.na.replace('Alice').show() + >>> df4.na.replace({'Alice': None}).show() +----+------+----+ | age|height|name| +----+------+----+ @@ -1573,6 +1648,12 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ """ + if value is _NoValue: + if isinstance(to_replace, dict): + value = None + else: + raise TypeError("value argument is required when to_replace is not a dictionary.") + # Helper functions def all_of(types): """Given a type or tuple of types and a sequence of xs @@ -1805,11 +1886,15 @@ def withColumn(self, colName, col): Returns a new :class:`DataFrame` by adding a column or replacing the existing column that has the same name. + The column expression must be an expression over this DataFrame; attempting to add + a column from some other dataframe will raise an error. + :param colName: string, name of the new column. :param col: a :class:`Column` expression for the new column. >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ assert isinstance(col, Column), "col should be Column" return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) @@ -1821,7 +1906,7 @@ def withColumnRenamed(self, existing, new): This is a no-op if schema doesn't contain the given column name. :param existing: string, name of the existing column to rename. - :param col: string, new name of the column. + :param new: string, new name of the column. >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] @@ -1895,6 +1980,9 @@ def toPandas(self): 0 2 Alice 1 5 Bob """ + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + import pandas as pd if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ @@ -1904,50 +1992,92 @@ def toPandas(self): timezone = None if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + use_arrow = True try: - from pyspark.sql.types import _check_dataframe_localize_timestamps + from pyspark.sql.types import to_arrow_schema from pyspark.sql.utils import require_minimum_pyarrow_version - import pyarrow + require_minimum_pyarrow_version() - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) - pdf = table.to_pandas() - return _check_dataframe_localize_timestamps(pdf, timezone) + to_arrow_schema(self.schema) + except Exception as e: + + if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempting non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." 
% _exception_message(e)) + warnings.warn(msg) + use_arrow = False else: - return pd.DataFrame.from_records([], columns=self.columns) - except ImportError as e: - msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enabled=true" - raise ImportError("%s\n%s" % (_exception_message(e), msg)) + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and will not continue because automatic fallback " + "with 'spark.sql.execution.arrow.fallback.enabled' has been set to " + "false.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise + + # Try to use Arrow optimization when the schema is supported and the required version + # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. + if use_arrow: + try: + from pyspark.sql.types import _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps + import pyarrow + + tables = self._collectAsArrow() + if tables: + table = pyarrow.concat_tables(tables) + pdf = table.to_pandas() + pdf = _check_dataframe_convert_date(pdf, self.schema) + return _check_dataframe_localize_timestamps(pdf, timezone) + else: + return pd.DataFrame.from_records([], columns=self.columns) + except Exception as e: + # We might have to allow fallback here as well but multiple Spark jobs can + # be executed. So, simply fail in this case for now. + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and can not continue. Note that " + "'spark.sql.execution.arrow.fallback.enabled' does not have an effect " + "on failures in the middle of computation.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise + + # Below is toPandas without Arrow optimization. + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + # SPARK-21766: if an integer field is nullable and has null values, it can be + # inferred by pandas as float column. Once we convert the column with NaN back + # to integer type e.g., np.int16, we will hit exception. So we use the inferred + # float type, not the corrected type from the schema in this case. + if pandas_type is not None and \ + not(isinstance(field.dataType, IntegralType) and field.nullable and + pdf[field.name].isnull().any()): + dtype[field.name] = pandas_type + + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + + if timezone is None: + return pdf else: - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - - dtype = {} + from pyspark.sql.types import _check_series_convert_timestamps_local_tz for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - # SPARK-21766: if an integer field is nullable and has null values, it can be - # inferred by pandas as float column. Once we convert the column with NaN back - # to integer type e.g., np.int16, we will hit exception. So we use the inferred - # float type, not the corrected type from the schema in this case. 
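The reworked `toPandas` path above is driven by the two session configs that appear in its warning messages: Arrow conversion is attempted only when `spark.sql.execution.arrow.enabled` is true, and an unsupported schema falls back to the plain collect-based path unless `spark.sql.execution.arrow.fallback.enabled` is set to false. An illustrative usage sketch, assuming an active SparkSession named `spark` and a DataFrame `df`:

    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
    # Optionally fail fast instead of silently falling back when the schema
    # cannot be converted to Arrow:
    spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
    pdf = df.toPandas()  # takes the Arrow-optimized path when the schema allows it
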
- if pandas_type is not None and \ - not(isinstance(field.dataType, IntegralType) and field.nullable and - pdf[field.name].isnull().any()): - dtype[field.name] = pandas_type - - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - - if timezone is None: - return pdf - else: - from pyspark.sql.types import _check_series_convert_timestamps_local_tz - for field in self.schema: - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if isinstance(field.dataType, TimestampType): - pdf[field.name] = \ - _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) - return pdf + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if isinstance(field.dataType, TimestampType): + pdf[field.name] = \ + _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) + return pdf def _collectAsArrow(self): """ @@ -1991,7 +2121,6 @@ def _to_corrected_pandas_type(dt): """ When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong. This method gets the corrected data type for Pandas if that type may be inferred uncorrectly. - NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns] """ import numpy as np if type(dt) == ByteType: @@ -2002,8 +2131,6 @@ def _to_corrected_pandas_type(dt): return np.int32 elif type(dt) == FloatType: return np.float32 - elif type(dt) == DateType: - return 'datetime64[ns]' else: return None @@ -2027,7 +2154,7 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ - def replace(self, to_replace, value=None, subset=None): + def replace(self, to_replace, value=_NoValue, subset=None): return self.df.replace(to_replace, value, subset) replace.__doc__ = DataFrame.replace.__doc__ @@ -2105,7 +2232,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a4ed562ad48b..ac3c79766702 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -18,7 +18,6 @@ """ A collections of builtin functions """ -import math import sys import functools import warnings @@ -28,10 +27,10 @@ from pyspark import since, SparkContext from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import StringType, DataType +# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 from pyspark.sql.udf import UserDefinedFunction, _create_udf @@ -106,18 +105,15 @@ def _(): _functions_1_4 = { # unary math functions - 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' + - '0.0 through pi.', - 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + - '-pi/2 through pi/2.', - 'atan': 'Computes the tangent inverse of the given value; the returned angle is in the range' + - '-pi/2 through pi/2', + 'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`', + 'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`', + 'atan': ':return: inverse tangent of `col`, as if computed by `java.lang.Math.atan()`', 'cbrt': 'Computes the cube-root of the given value.', 
'ceil': 'Computes the ceiling of the given value.', - 'cos': """Computes the cosine of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'cosh': 'Computes the hyperbolic cosine of the given value.', + 'cos': """:param col: angle in radians + :return: cosine of the angle, as if computed by `java.lang.Math.cos()`.""", + 'cosh': """:param col: hyperbolic angle + :return: hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh()`""", 'exp': 'Computes the exponential of the given value.', 'expm1': 'Computes the exponential of the given value minus one.', 'floor': 'Computes the floor of the given value.', @@ -127,19 +123,32 @@ def _(): 'rint': 'Returns the double value that is closest in value to the argument and' + ' is equal to a mathematical integer.', 'signum': 'Computes the signum of the given value.', - 'sin': """Computes the sine of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'sinh': 'Computes the hyperbolic sine of the given value.', - 'tan': """Computes the tangent of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'tanh': 'Computes the hyperbolic tangent of the given value.', + 'sin': """:param col: angle in radians + :return: sine of the angle, as if computed by `java.lang.Math.sin()`""", + 'sinh': """:param col: hyperbolic angle + :return: hyperbolic sine of the given value, + as if computed by `java.lang.Math.sinh()`""", + 'tan': """:param col: angle in radians + :return: tangent of the given value, as if computed by `java.lang.Math.tan()`""", + 'tanh': """:param col: hyperbolic angle + :return: hyperbolic tangent of the given value, + as if computed by `java.lang.Math.tanh()`""", 'toDegrees': '.. note:: Deprecated in 2.1, use :func:`degrees` instead.', 'toRadians': '.. note:: Deprecated in 2.1, use :func:`radians` instead.', 'bitwiseNOT': 'Computes bitwise not.', } +_functions_2_4 = { + 'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' + + ' column name, and null values return before non-null values.', + 'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' + + ' column name, and null values appear after non-null values.', + 'desc_nulls_first': 'Returns a sort expression based on the descending order of the given' + + ' column name, and null values appear before non-null values.', + 'desc_nulls_last': 'Returns a sort expression based on the descending order of the given' + + ' column name, and null values appear after non-null values', +} + _collect_list_doc = """ Aggregate function: returns a list of objects with duplicates. @@ -173,16 +182,31 @@ def _(): _functions_2_1 = { # unary math functions - 'degrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + - 'measured in degrees.', - 'radians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + - 'measured in radians.', + 'degrees': """ + Converts an angle measured in radians to an approximately equivalent angle + measured in degrees. + :param col: angle in radians + :return: angle in degrees, as if computed by `java.lang.Math.toDegrees()` + """, + 'radians': """ + Converts an angle measured in degrees to an approximately equivalent angle + measured in radians. 
+ :param col: angle in degrees + :return: angle in radians, as if computed by `java.lang.Math.toRadians()` + """, } # math functions that take two arguments as input _binary_mathfunctions = { - 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + - 'polar coordinates (r, theta). Units in radians.', + 'atan2': """ + :param col1: coordinate on y-axis + :param col2: coordinate on x-axis + :return: the `theta` component of the point + (`r`, `theta`) + in polar coordinates that corresponds to the point + (`x`, `y`) in Cartesian coordinates, + as if computed by `java.lang.Math.atan2()` + """, 'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.', 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } @@ -237,6 +261,8 @@ def _(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) for _name, _message in _functions_deprecated.items(): globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) +for _name, _doc in _functions_2_4.items(): + globals()[_name] = since(2.4)(_create_function(_name, _doc)) del _name, _doc @@ -809,6 +835,36 @@ def ntile(n): return Column(sc._jvm.functions.ntile(int(n))) +@since(2.4) +def unboundedPreceding(): + """ + Window function: returns the special frame boundary that represents the first row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unboundedPreceding()) + + +@since(2.4) +def unboundedFollowing(): + """ + Window function: returns the special frame boundary that represents the last row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unboundedFollowing()) + + +@since(2.4) +def currentRow(): + """ + Window function: returns the special frame boundary that represents the current row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.currentRow()) + + # ---------------------- Date/Timestamp functions ------------------------------ @since(1.5) @@ -1032,16 +1088,20 @@ def add_months(start, months): @since(1.5) -def months_between(date1, date2): +def months_between(date1, date2, roundOff=True): """ Returns the number of months between date1 and date2. + Unless `roundOff` is set to `False`, the result is rounded off to 8 digits. >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() - [Row(months=3.9495967...)] + [Row(months=3.94959677)] + >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect() + [Row(months=3.9495967741935485)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) + return Column(sc._jvm.functions.months_between( + _to_java_column(date1), _to_java_column(date2), roundOff)) @since(2.2) @@ -1358,7 +1418,6 @@ def hash(*cols): 'uppercase. 
Words are delimited by whitespace.', 'lower': 'Converts a string column to lower case.', 'upper': 'Converts a string column to upper case.', - 'reverse': 'Reverses the string column and returns it as a new string column.', 'ltrim': 'Trim the spaces from left end for the specified string value.', 'rtrim': 'Trim the spaces from right end for the specified string value.', 'trim': 'Trim the spaces from both ends for the specified string column.', @@ -1370,21 +1429,6 @@ def hash(*cols): del _name, _doc -@since(1.5) -@ignore_unicode_prefix -def concat(*cols): - """ - Concatenates multiple input columns together into a single column. - If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df.select(concat(df.s, df.d).alias('s')).collect() - [Row(s=u'abcd123')] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) - - @since(1.5) @ignore_unicode_prefix def concat_ws(sep, *cols): @@ -1705,10 +1749,12 @@ def unhex(col): @ignore_unicode_prefix @since(1.5) def length(col): - """Calculates the length of a string or binary expression. + """Computes the character length of string data or number of bytes of binary data. + The length of character data includes the trailing spaces. The length of binary data + includes binary zeros. - >>> spark.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() - [Row(length=3)] + >>> spark.createDataFrame([('ABC ',)], ['a']).select(length('a').alias('length')).collect() + [Row(length=4)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.length(_to_java_column(col))) @@ -1737,8 +1783,8 @@ def translate(srcCol, matching, replace): def create_map(*cols): """Creates a new map column. - :param cols: list of column names (string) or list of :class:`Column` expressions that grouped - as key-value pairs, e.g. (key1, value1, key2, value2, ...). + :param cols: list of column names (string) or list of :class:`Column` expressions that are + grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...). >>> df.select(create_map('name', 'age').alias("map")).collect() [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})] @@ -1788,6 +1834,100 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def slice(x, start, length): + """ + Collection function: returns an array containing all the elements in `x` from index `start` + (or starting from the end if `start` is negative) with the specified `length`. + >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() + [Row(sliced=[2, 3]), Row(sliced=[5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.slice(_to_java_column(x), start, length)) + + +@ignore_unicode_prefix +@since(2.4) +def array_join(col, delimiter, null_replacement=None): + """ + Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + `null_replacement` if set, otherwise they are ignored. 
+ + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df.select(array_join(df.data, ",").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a')] + >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')] + """ + sc = SparkContext._active_spark_context + if null_replacement is None: + return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter)) + else: + return Column(sc._jvm.functions.array_join( + _to_java_column(col), delimiter, null_replacement)) + + +@since(1.5) +@ignore_unicode_prefix +def concat(*cols): + """ + Concatenates multiple input columns together into a single column. + The function works with strings, binary and compatible array columns. + + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat(df.s, df.d).alias('s')).collect() + [Row(s=u'abcd123')] + + >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) + >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect() + [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) + + +@since(2.4) +def array_position(col, value): + """ + Collection function: Locates the position of the first occurrence of the given value + in the given array. Returns null if either of the arguments are null. + + .. note:: The position is not zero based, but 1 based index. Returns 0 if the given + value could not be found in the array. + + >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + >>> df.select(array_position(df.data, "a")).collect() + [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) + + +@ignore_unicode_prefix +@since(2.4) +def element_at(col, extraction): + """ + Collection function: Returns element of array at given index in extraction if col is array. + Returns value for the given key in extraction if col is map. + + :param col: name of column containing array or map + :param extraction: index to check for in array or key to check for in map + + .. note:: The position is not zero based, but 1 based index. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(element_at(df.data, 1)).collect() + [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)] + + >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) + >>> df.select(element_at(df.data, "a")).collect() + [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. 
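Note that `array_position` and `element_at` above share the same 1-based indexing convention, so the index reported by one lines up with the element returned by the other. A small usage sketch, assuming an active SparkSession named `spark`:

    from pyspark.sql.functions import array_position, element_at

    df = spark.createDataFrame([(["c", "b", "a"],)], ["data"])
    df.select(array_position(df.data, "a")).collect()  # [Row(array_position(data, a)=3)]
    df.select(element_at(df.data, 3)).collect()        # [Row(element_at(data, 3)=u'a')]
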
@@ -1849,14 +1989,14 @@ def explode_outer(col): +---+----------+----+-----+ >>> df.select("id", "a_map", explode_outer("an_array")).show() - +---+-------------+----+ - | id| a_map| col| - +---+-------------+----+ - | 1|Map(x -> 1.0)| foo| - | 1|Map(x -> 1.0)| bar| - | 2| Map()|null| - | 3| null|null| - +---+-------------+----+ + +---+----------+----+ + | id| a_map| col| + +---+----------+----+ + | 1|[x -> 1.0]| foo| + | 1|[x -> 1.0]| bar| + | 2| []|null| + | 3| null|null| + +---+----------+----+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.explode_outer(_to_java_column(col)) @@ -1881,14 +2021,14 @@ def posexplode_outer(col): | 3| null|null|null| null| +---+----------+----+----+-----+ >>> df.select("id", "a_map", posexplode_outer("an_array")).show() - +---+-------------+----+----+ - | id| a_map| pos| col| - +---+-------------+----+----+ - | 1|Map(x -> 1.0)| 0| foo| - | 1|Map(x -> 1.0)| 1| bar| - | 2| Map()|null|null| - | 3| null|null|null| - +---+-------------+----+----+ + +---+----------+----+----+ + | id| a_map| pos| col| + +---+----------+----+----+ + | 1|[x -> 1.0]| 0| foo| + | 1|[x -> 1.0]| 1| bar| + | 2| []|null|null| + | 3| null|null|null| + +---+----------+----+----+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.posexplode_outer(_to_java_column(col)) @@ -2022,24 +2162,108 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_min(col): + """ + Collection function: returns the minimum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_min(df.data).alias('min')).collect() + [Row(min=1), Row(min=-1)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_min(_to_java_column(col))) + + +@since(2.4) +def array_max(col): + """ + Collection function: returns the maximum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_max(df.data).alias('max')).collect() + [Row(max=3), Row(max=10)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_max(_to_java_column(col))) + + @since(1.5) def sort_array(col, asc=True): """ Collection function: sorts the input array in ascending or descending order according - to the natural ordering of the array elements. + to the natural ordering of the array elements. Null elements will be placed at the beginning + of the returned array in ascending order or at the end of the returned array in descending + order. :param col: name of column or expression - >>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) >>> df.select(sort_array(df.data).alias('r')).collect() - [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] + [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() - [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] - """ + [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] + """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(2.4) +def array_sort(col): + """ + Collection function: sorts the input array in ascending order. The elements of the input array + must be orderable. Null elements will be placed at the end of the returned array. 
+ + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) + >>> df.select(array_sort(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_sort(_to_java_column(col))) + + +@since(1.5) +@ignore_unicode_prefix +def reverse(col): + """ + Collection function: returns a reversed string or an array with reverse order of elements. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) + >>> df.select(reverse(df.data).alias('s')).collect() + [Row(s=u'LQS krapS')] + >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) + >>> df.select(reverse(df.data).alias('r')).collect() + [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.reverse(_to_java_column(col))) + + +@since(2.4) +def flatten(col): + """ + Collection function: creates a single array from an array of arrays. + If a structure of nested arrays is deeper than two levels, + only one level of nesting is removed. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + >>> df.select(flatten(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.flatten(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ @@ -2085,9 +2309,11 @@ def map_values(col): class PandasUDFType(object): """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`. """ - SCALAR = PythonEvalType.SQL_PANDAS_SCALAR_UDF + SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF + + GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF - GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF + GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF @since(1.3) @@ -2103,12 +2329,15 @@ def udf(f=None, returnType=StringType()): >>> import random >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() - .. note:: The user-defined functions do not support conditional expressions or short curcuiting + .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. + .. note:: The user-defined functions do not take keyword arguments on the calling side. + :param f: python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the user-defined function. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) @@ -2148,7 +2377,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): Creates a vectorized user defined function (UDF). :param f: user-defined function. A python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the user-defined function. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. 
:param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. Default: SCALAR. @@ -2157,7 +2387,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): 1. SCALAR A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The returnType should be a primitive data type, e.g., `DoubleType()`. + The returnType should be a primitive data type, e.g., :class:`DoubleType`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and @@ -2184,20 +2414,25 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 8| JOHN DOE| 22| +----------+--------------+------------+ - 2. GROUP_MAP + .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input + column, but is the length of an internal batch used for each call to the function. + Therefore, this can be used, for example, to ensure the length of each returned + `pandas.Series`, and can not be used as the column length. + + 2. GROUPED_MAP - A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` + A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. The length of the returned `pandas.DataFrame` can be arbitrary. - Group map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. + Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) # doctest: +SKIP - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) @@ -2212,13 +2447,79 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 1.1094003924504583| +---+-------------------+ + Alternatively, the user can define a function that takes two arguments. + In this case, the grouping key will be passed as the first argument and the data will + be passed as the second argument. The grouping key will be passed as a tuple of numpy + data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in + as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. + This is useful when the user does not want to hardcode grouping key in the function. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> import pandas as pd # doctest: +SKIP + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP + ... def mean_udf(key, pdf): + ... # key is a tuple of one numpy.int64, which is the value + ... # of 'id' for the current group + ... return pd.DataFrame([key + (pdf.v.mean(),)]) + >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP + +---+---+ + | id| v| + +---+---+ + | 1|1.5| + | 2|6.0| + +---+---+ + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - .. note:: The user-defined function must be deterministic. + 3. 
GROUPED_AGG
+
+    A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar.
+    The `returnType` should be a primitive data type, e.g., :class:`DoubleType`.
+    The returned scalar can be either a python primitive type, e.g., `int` or `float`,
+    or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
+
+    :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as
+    output types.
+
+    Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg`.
+
+    >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+    >>> df = spark.createDataFrame(
+    ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+    ...     ("id", "v"))
+    >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
+    ... def mean_udf(v):
+    ...     return v.mean()
+    >>> df.groupby("id").agg(mean_udf(df['v'])).show()  # doctest: +SKIP
+    +---+-----------+
+    | id|mean_udf(v)|
+    +---+-----------+
+    |  1|        1.5|
+    |  2|        6.0|
+    +---+-----------+
+
+    .. seealso:: :meth:`pyspark.sql.GroupedData.agg`
+
+    .. note:: The user-defined functions are considered deterministic by default. Due to
+        optimization, duplicate invocations may be eliminated or the function may even be invoked
+        more times than it is present in the query. If your function is not deterministic, call
+        `asNondeterministic` on the user-defined function. E.g.:
-    .. note:: The user-defined functions do not support conditional expressions or short curcuiting
+    >>> @pandas_udf('double', PandasUDFType.SCALAR)  # doctest: +SKIP
+    ... def random(v):
+    ...     import numpy as np
+    ...     import pandas as pd
+    ...     return pd.Series(np.random.randn(len(v)))
+    >>> random = random.asNondeterministic()  # doctest: +SKIP
+
+    .. note:: The user-defined functions do not support conditional expressions or short circuiting
+        in boolean expressions and it ends up with being executed all internally. If the functions
+        can fail on special rows, the workaround is to incorporate the condition into the functions.
+
+    .. note:: The user-defined functions do not take keyword arguments on the calling side.
     """
     # decorator @pandas_udf(returnType, functionType)
     is_decorator = f is None or isinstance(f, (str, DataType))
@@ -2237,20 +2538,21 @@ def pandas_udf(f=None, returnType=None, functionType=None):
             eval_type = returnType
         else:
             # @pandas_udf(dataType) or @pandas_udf(returnType=dataType)
-            eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF
+            eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
     else:
         return_type = returnType
         if functionType is not None:
             eval_type = functionType
         else:
-            eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF
+            eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
     if return_type is None:
         raise ValueError("Invalid returnType: returnType can not be None")
-    if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF,
-                         PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]:
+    if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+                         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]:
         raise ValueError("Invalid functionType: "
                          "functionType must be one the values from PandasUDFType")
@@ -2284,7 +2586,7 @@ def _test():
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
     spark.stop()
     if failure_count:
-        exit(-1)
+        sys.exit(-1)
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 09fae46adf01..3505065b648f 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -15,11 +15,12 @@
 # limitations under the License.
# +import sys + from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal +from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame -from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -65,13 +66,27 @@ def __init__(self, jgd, df): def agg(self, *exprs): """Compute aggregates and returns the result as a :class:`DataFrame`. - The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + The available aggregate functions can be: + + 1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count` + + 2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf` + + .. note:: There is no partial aggregation with group aggregate UDFs, i.e., + a full shuffle is required. Also, all the data of a group will be loaded into + memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + .. seealso:: :func:`pyspark.sql.functions.pandas_udf` If ``exprs`` is a single :class:`dict` mapping from string to string, then the key is the column to perform aggregation on, and the value is the aggregate function. Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed + in a single call to this function. + :param exprs: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. @@ -82,6 +97,13 @@ def agg(self, *exprs): >>> from pyspark.sql import functions as F >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG) # doctest: +SKIP + ... def min_udf(v): + ... return v.min() + >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP + [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -204,21 +226,24 @@ def apply(self, udf): The user-defined function should take a `pandas.DataFrame` and return another `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` - to the user-function and the returned `pandas.DataFrame`s are combined as a + to the user-function and the returned `pandas.DataFrame` are combined as a :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the returnType of the pandas udf. - This function does not support partial aggregation, and requires shuffling all the data in - the :class:`DataFrame`. + .. note:: This function requires a full shuffle. all the data of a group will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. - :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + :param udf: a grouped map user-defined function returned by + :func:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... 
("id", "v")) - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) @@ -238,9 +263,9 @@ def apply(self, udf): """ # Columns are special because hasattr always return True if isinstance(udf, Column) or not hasattr(udf, 'func') \ - or udf.evalType != PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " - "GROUP_MAP.") + "GROUPED_MAP.") df = self._df udf_column = udf(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) @@ -275,7 +300,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 49af1bcee5ef..448a4732001b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -22,7 +22,7 @@ from py4j.java_gateway import JavaClass -from pyspark import RDD, since, keyword_only +from pyspark import RDD, since from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * @@ -147,8 +147,8 @@ def load(self, path=None, format=None, schema=None, **options): or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param options: all other string options - >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, - ... opt2=1, opt3='str') + >>> df = spark.read.format("parquet").load('python/test_support/sql/parquet_partitioned', + ... opt1=True, opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] @@ -176,7 +176,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, + encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -209,13 +210,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ - schema. If a schema does not have the field, it drops corrupt records during \ - parsing. When inferring a schema, it implicitly adds a \ - ``columnNameOfCorruptRecord`` field in an output schema. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. 
To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ + field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -237,6 +238,14 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param encoding: allows to forcibly set one of standard basic or extended encoding for + the JSON files. For example UTF-16BE, UTF-32LE. If None is set, + the encoding of input JSON will be detected automatically + when the multiLine option is set to ``true``. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + :param samplingRatio: defines fraction of input JSON objects used for schema inferring. + If None is set, it uses the default value, ``1.0``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -254,7 +263,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, + samplingRatio=samplingRatio, encoding=encoding) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -304,16 +314,18 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, paths, wholetext=False): + def text(self, paths, wholetext=False, lineSep=None): """ Loads text files and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there are any. - Each line in the text file is a new row in the resulting DataFrame. + By default, each line in the text file is a new row in the resulting DataFrame. :param paths: string, or list of strings, for input path(s). :param wholetext: if true, read each file from input path(s) as a single row. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. 
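A short, hedged sketch of the new lineSep option on the text reader described above, alongside the existing wholetext flag (assumes `spark` exists; the input path is illustrative only):

    # Split the file on commas instead of newlines; each field becomes one row in the "value" column.
    df = spark.read.text("/tmp/ages.csv", lineSep=",")       # hypothetical path
    # Alternatively, read each file from the input path(s) as a single row.
    whole = spark.read.text("/tmp/ages.csv", wholetext=True)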
>>> df = spark.read.text('python/test_support/sql/text-test.txt') >>> df.collect() @@ -322,7 +334,7 @@ def text(self, paths, wholetext=False): >>> df.collect() [Row(value=u'hello\\nthis')] """ - self._set_opts(wholetext=wholetext) + self._set_opts(wholetext=wholetext, lineSep=lineSep) if isinstance(paths, basestring): paths = [paths] return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @@ -333,7 +345,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + samplingRatio=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -393,13 +406,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an \ - user-defined schema. If a schema does not have the field, it drops corrupt \ - records during parsing. When a length of parsed CSV tokens is shorter than \ - an expected length of a schema, it sets `null` for extra fields. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + A record with less/more tokens than schema is not a corrupted record to CSV. \ + When it meets a record having fewer tokens than the length of the schema, \ + sets ``null`` to extra fields. When the record has more tokens than the \ + length of the schema, it drops extra tokens. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -414,6 +429,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non the quote character. If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise. + :param samplingRatio: defines fraction of rows used for schema inferring. + If None is set, it uses the default value, ``1.0``. 
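A hedged sketch of schema inference with the new samplingRatio option on the CSV reader (assumes `spark` exists; the path is illustrative, and header/inferSchema are pre-existing options):

    # Infer the schema from roughly 10% of the rows instead of scanning the whole input.
    df = spark.read.csv("/tmp/large_dataset.csv",             # hypothetical path
                        header=True, inferSchema=True, samplingRatio=0.1)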
>>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -432,7 +449,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -742,7 +759,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): + def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, + lineSep=None, encoding=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -766,12 +784,17 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param encoding: specifies encoding (charset) of saved json files. If None is set, + the default UTF-8 charset will be used. + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts( - compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat) + compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, + lineSep=lineSep, encoding=encoding) self._jwrite.json(path) @since(1.4) @@ -802,18 +825,20 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): self._jwrite.parquet(path) @since(1.6) - def text(self, path, compression=None): + def text(self, path, compression=None, lineSep=None): """Saves the content of the DataFrame in a text file at the specified path. :param path: the path in any Hadoop supported file system :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. 
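A hedged sketch of the new writer-side line separator, mirroring the lineSep round trip exercised in the tests further below (assumes `spark` exists and `df` has a single string column; paths are illustrative):

    # Write records separated by '!' instead of '\n', then read them back with the same separator.
    df.write.text("/tmp/out_text", lineSep="!")               # hypothetical path
    readback = spark.read.text("/tmp/out_text", lineSep="!")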
""" - self._set_opts(compression=compression) + self._set_opts(compression=compression, lineSep=lineSep) self._jwrite.text(path) @since(2.0) @@ -954,7 +979,7 @@ def _test(): globs = pyspark.sql.readwriter.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') try: - spark = SparkSession.builder.enableHiveSupport().getOrCreate() + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: spark = SparkSession(sc) @@ -968,7 +993,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) sc.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6e5eec48e8ac..13d6e2e53dbd 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -29,7 +29,6 @@ from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix -from pyspark.sql.catalog import Catalog from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader @@ -214,7 +213,12 @@ def __init__(self, sparkContext, jsparkSession=None): self._jsc = self._sc._jsc self._jvm = self._sc._jvm if jsparkSession is None: - jsparkSession = self._jvm.SparkSession(self._jsc.sc()) + if self._jvm.SparkSession.getDefaultSession().isDefined() \ + and not self._jvm.SparkSession.getDefaultSession().get() \ + .sparkContext().isStopped(): + jsparkSession = self._jvm.SparkSession.getDefaultSession().get() + else: + jsparkSession = self._jvm.SparkSession(self._jsc.sc()) self._jsparkSession = jsparkSession self._jwrapped = self._jsparkSession.sqlContext() self._wrapped = SQLContext(self._sc, self, self._jwrapped) @@ -226,6 +230,7 @@ def __init__(self, sparkContext, jsparkSession=None): if SparkSession._instantiatedSession is None \ or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + self._jvm.SparkSession.setDefaultSession(self._jsparkSession) def _repr_html_(self): return """ @@ -280,6 +285,7 @@ def catalog(self): :return: :class:`Catalog` """ + from pyspark.sql.catalog import Catalog if not hasattr(self, "_catalog"): self._catalog = Catalog(self) return self._catalog @@ -291,8 +297,8 @@ def udf(self): :return: :class:`UDFRegistration` """ - from pyspark.sql.context import UDFRegistration - return UDFRegistration(self._wrapped) + from pyspark.sql.udf import UDFRegistration + return UDFRegistration(self) @since(2.0) def range(self, start, end=None, step=1, numPartitions=None): @@ -325,11 +331,12 @@ def range(self, start, end=None, step=1, numPartitions=None): return DataFrame(jdf, self._wrapped) - def _inferSchemaFromList(self, data): + def _inferSchemaFromList(self, data, names=None): """ Infer schema from list of Row or tuple. :param data: list of Row or tuple + :param names: list of column names :return: :class:`pyspark.sql.types.StructType` """ if not data: @@ -338,12 +345,12 @@ def _inferSchemaFromList(self, data): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = reduce(_merge_type, map(_infer_schema, data)) + schema = reduce(_merge_type, (_infer_schema(row, names) for row in data)) if _has_nulltype(schema): raise ValueError("Some of types cannot be determined after inferring") return schema - def _inferSchema(self, rdd, samplingRatio=None): + def _inferSchema(self, rdd, samplingRatio=None, names=None): """ Infer schema from an RDD of Row or tuple. 
@@ -360,10 +367,10 @@ def _inferSchema(self, rdd, samplingRatio=None): "Use pyspark.sql.Row instead") if samplingRatio is None: - schema = _infer_schema(first) + schema = _infer_schema(first, names=names) if _has_nulltype(schema): for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) + schema = _merge_type(schema, _infer_schema(row, names=names)) if not _has_nulltype(schema): break else: @@ -372,7 +379,7 @@ def _inferSchema(self, rdd, samplingRatio=None): else: if samplingRatio < 0.99: rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) + schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type) return schema def _createFromRDD(self, rdd, schema, samplingRatio): @@ -380,7 +387,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio): Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. """ if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchema(rdd, samplingRatio) + struct = self._inferSchema(rdd, samplingRatio, names=schema) converter = _create_converter(struct) rdd = rdd.map(converter) if isinstance(schema, (list, tuple)): @@ -406,7 +413,7 @@ def _createFromLocal(self, data, schema): data = list(data) if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchemaFromList(data) + struct = self._inferSchemaFromList(data, names=schema) converter = _create_converter(struct) data = map(converter, data) if isinstance(schema, (list, tuple)): @@ -458,21 +465,23 @@ def _convert_from_pandas(self, pdf, schema, timezone): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if isinstance(field.dataType, TimestampType): s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone) - if not copied and s is not pdf[field.name]: - # Copy once if the series is modified to prevent the original Pandas - # DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[field.name] = s + if s is not pdf[field.name]: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[field.name] = s else: for column, series in pdf.iteritems(): - s = _check_series_convert_timestamps_tz_local(pdf[column], timezone) - if not copied and s is not pdf[column]: - # Copy once if the series is modified to prevent the original Pandas - # DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[column] = s + s = _check_series_convert_timestamps_tz_local(series, timezone) + if s is not series: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[column] = s # Convert pandas.DataFrame to list of numpy records np_records = pdf.to_records(index=False) @@ -637,6 +646,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ == "true": timezone = self.conf.get("spark.sql.session.timeZone") @@ -645,15 +657,36 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr # If no schema supplied by user then get the names of columns only if schema is None: - schema = [x.encode('utf-8') if not 
isinstance(x, str) else x for x in data.columns] + schema = [str(x) if not isinstance(x, basestring) else + (x.encode('utf-8') if not isinstance(x, str) else x) + for x in data.columns] if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ and len(data) > 0: try: return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: - warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e)) - # Fallback to create DataFrame without arrow if raise some exception + from pyspark.util import _exception_message + + if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempting non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." % _exception_message(e)) + warnings.warn(msg) + else: + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and will not continue because automatic fallback " + "with 'spark.sql.execution.arrow.fallback.enabled' has been set to " + "false.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): @@ -754,6 +787,8 @@ def stop(self): """Stop the underlying :class:`SparkContext`. """ self._sc.stop() + # We should clean the default session up. See SPARK-23228. + self._jvm.SparkSession.clearDefaultSession() SparkSession._instantiatedSession = None @since(2.0) @@ -796,7 +831,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fb228f99ba7a..15f940738986 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -24,8 +24,6 @@ else: intlike = (int, long) -from abc import ABCMeta, abstractmethod - from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq @@ -407,7 +405,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -442,13 +440,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ - schema. If a schema does not have the field, it drops corrupt records during \ - parsing. 
When inferring a schema, it implicitly adds a \ - ``columnNameOfCorruptRecord`` field in an output schema. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ + field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -470,6 +468,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -484,7 +484,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: @@ -531,17 +531,20 @@ def parquet(self, path): @ignore_unicode_prefix @since(2.0) - def text(self, path): + def text(self, path, wholetext=False, lineSep=None): """ Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there are any. - Each line in the text file is a new row in the resulting DataFrame. + By default, each line in the text file is a new row in the resulting DataFrame. .. note:: Evolving. :param paths: string, or list of strings, for input path(s). + :param wholetext: if true, read each file from input path(s) as a single row. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> text_sdf = spark.readStream.text(tempfile.mkdtemp()) >>> text_sdf.isStreaming @@ -549,6 +552,7 @@ def text(self, path): >>> "value" in str(text_sdf.schema) True """ + self._set_opts(wholetext=wholetext, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.text(path)) else: @@ -621,13 +625,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an \ - user-defined schema. If a schema does not have the field, it drops corrupt \ - records during parsing. 
When a length of parsed CSV tokens is shorter than \ - an expected length of a schema, it sets `null` for extra fields. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + A record with less/more tokens than schema is not a corrupted record to CSV. \ + When it meets a record having fewer tokens than the length of the schema, \ + sets ``null`` to extra fields. When the record has more tokens than the \ + length of the schema, it drops extra tokens. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -786,35 +792,54 @@ def queryName(self, queryName): @keyword_only @since(2.0) - def trigger(self, processingTime=None, once=None): + def trigger(self, processingTime=None, once=None, continuous=None): """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. .. note:: Evolving. :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + Set a trigger that runs a query periodically based on the processing + time. Only one trigger can be set. + :param once: if set to True, set a trigger that processes only one batch of data in a + streaming query then terminates the query. Only one trigger can be set. >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') >>> # trigger the query for just once batch of data >>> writer = sdf.writeStream.trigger(once=True) + >>> # trigger the query for execution every 5 seconds + >>> writer = sdf.writeStream.trigger(continuous='5 seconds') """ + params = [processingTime, once, continuous] + + if params.count(None) == 3: + raise ValueError('No trigger provided') + elif params.count(None) < 2: + raise ValueError('Multiple triggers not allowed.') + jTrigger = None if processingTime is not None: - if once is not None: - raise ValueError('Multiple triggers not allowed.') if type(processingTime) != str or len(processingTime.strip()) == 0: raise ValueError('Value for processingTime must be a non empty string. Got: %s' % processingTime) interval = processingTime.strip() jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime( interval) + elif once is not None: if once is not True: raise ValueError('Value for once must be True. Got: %s' % once) jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once() + else: - raise ValueError('No trigger provided') + if type(continuous) != str or len(continuous.strip()) == 0: + raise ValueError('Value for continuous must be a non empty string. 
Got: %s' % + continuous) + interval = continuous.strip() + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Continuous( + interval) + self._jwrite = self._jwrite.trigger(jTrigger) return self @@ -909,7 +934,7 @@ def _test(): globs['spark'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 67bdb3d72d93..1f91d2c18168 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -32,7 +32,9 @@ import datetime import array import ctypes +import warnings import py4j +from contextlib import contextmanager try: import xmlrunner @@ -48,19 +50,26 @@ else: import unittest -_have_pandas = False -_have_old_pandas = False +from pyspark.util import _exception_message + +_pandas_requirement_message = None try: - import pandas - try: - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() - _have_pandas = True - except: - _have_old_pandas = True -except: - # No Pandas, but that's okay, we'll skip those tests - pass + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() +except ImportError as e: + # If Pandas version requirement is not satisfied, skip related tests. + _pandas_requirement_message = _exception_message(e) + +_pyarrow_requirement_message = None +try: + from pyspark.sql.utils import require_minimum_pyarrow_version + require_minimum_pyarrow_version() +except ImportError as e: + # If Arrow version requirement is not satisfied, skip related tests. + _pyarrow_requirement_message = _exception_message(e) + +_have_pandas = _pandas_requirement_message is None +_have_pyarrow = _pyarrow_requirement_message is None from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row @@ -68,21 +77,13 @@ from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings -from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests +from pyspark.sql.types import _merge_type +from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException -_have_arrow = False -try: - import pyarrow - _have_arrow = True -except: - # No Arrow, but that's okay, we'll skip those tests - pass - - class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -185,7 +186,38 @@ def __init__(self, key, value): self.value = value -class ReusedSQLTestCase(ReusedPySparkTestCase): +class SQLTestUtils(object): + """ + This util assumes the instance of this to have 'spark' attribute, having a spark session. + It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the + the implementation of this class has 'spark' attribute. + """ + + @contextmanager + def sql_conf(self, pairs): + """ + A convenient context manager to test some configuration specific logic. This sets + `value` to the configuration `key` and then restores it back when it exits. + """ + assert isinstance(pairs, dict), "pairs should be a dictionary." + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." 
+ + keys = pairs.keys() + new_values = pairs.values() + old_values = [self.spark.conf.get(key, None) for key in keys] + for key, new_value in zip(keys, new_values): + self.spark.conf.set(key, new_value) + try: + yield + finally: + for key, old_value in zip(keys, old_values): + if old_value is None: + self.spark.conf.unset(key) + else: + self.spark.conf.set(key, old_value) + + +class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): @classmethod def setUpClass(cls): ReusedPySparkTestCase.setUpClass() @@ -196,6 +228,12 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + def assertPandasEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + + "\n\nResult:\n%s\n%s" % (result, result.dtypes)) + self.assertTrue(expected.equals(result), msg=msg) + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 @@ -371,6 +409,12 @@ def test_udf(self): [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. + sqlContext = self.spark._wrapped + sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) + [row] = sqlContext.sql("SELECT oneArg('test')").collect() + self.assertEqual(row[0], 4) + def test_udf2(self): self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ @@ -378,6 +422,81 @@ def test_udf2(self): [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_udf3(self): + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)) + self.assertEqual(two_args.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], u'5') + + def test_udf_registration_return_type_none(self): + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None) + self.assertEqual(two_args.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_udf_registration_return_type_not_none(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, "Invalid returnType"): + self.spark.catalog.registerFunction( + "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType()) + + def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + self.assertEqual(udf_random_col.deterministic, False) + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + + def test_nondeterministic_udf2(self): + import random + from pyspark.sql.functions import udf + random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() + self.assertEqual(random_udf.deterministic, False) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) + self.assertEqual(random_udf1.deterministic, False) + [row] = self.spark.sql("SELECT 
randInt()").collect() + self.assertEqual(row[0], 6) + [row] = self.spark.range(1).select(random_udf1()).collect() + self.assertEqual(row[0], 6) + [row] = self.spark.range(1).select(random_udf()).collect() + self.assertEqual(row[0], 6) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) + pydoc.render_doc(random_udf) + pydoc.render_doc(random_udf1) + pydoc.render_doc(udf(lambda x: x).asNondeterministic) + + def test_nondeterministic_udf3(self): + # regression test for SPARK-23233 + from pyspark.sql.functions import udf + f = udf(lambda x: x) + # Here we cache the JVM UDF instance. + self.spark.range(1).select(f("id")) + # This should reset the cache to set the deterministic status correctly. + f = f.asNondeterministic() + # Check the deterministic status of udf. + df = self.spark.range(1).select(f("id")) + deterministic = df._jdf.logicalPlan().projectList().head().deterministic() + self.assertFalse(deterministic) + + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import udf, sum + import random + udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() + df = self.spark.range(10) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.groupby('id').agg(sum(udf_random_col())).collect() + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.agg(sum(udf_random_col())).collect() + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -435,15 +554,6 @@ def test_udf_with_array_type(self): self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) - def test_nondeterministic_udf(self): - from pyspark.sql.functions import udf - import random - udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() - df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) - udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) - [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() - self.assertEqual(row[0] + 10, row[1]) - def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) @@ -523,22 +633,82 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + # This is to check if a 'SQLContext.udf' can call its alias. + sqlContext = self.spark._wrapped + add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_four(id) AS plus_four").collect(), + df.select(add_four("id").alias("plus_four")).collect() + ) + def test_non_existed_udf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. 
+ sqlContext = spark._wrapped + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf")) + def test_non_existed_udaf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) - def test_multiLine_json(self): + def test_linesep_text(self): + df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",") + expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'), + Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'), + Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'), + Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df.write.text(tpath, lineSep="!") + expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'), + Row(value=u'Tom!30!"My name is Tom"'), + Row(value=u'Hyukjin!25!"I am Hyukjin'), + Row(value=u''), Row(value=u'I love Spark!"'), + Row(value=u'!')] + readback = self.spark.read.text(tpath) + self.assertEqual(readback.collect(), expected) + finally: + shutil.rmtree(tpath) + + def test_multiline_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_encoding_json(self): + people_array = self.spark.read\ + .json("python/test_support/sql/people_array_utf16le.json", + multiLine=True, encoding="UTF-16LE") + expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] + self.assertEqual(people_array.collect(), expected) + + def test_linesep_json(self): + df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") + expected = [Row(_corrupt_record=None, name=u'Michael'), + Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None), + Row(_corrupt_record=u' "age":19}\n', name=None)] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df = self.spark.read.json("python/test_support/sql/people.json") + df.write.json(tpath, lineSep="!!") + readback = self.spark.read.json(tpath, lineSep="!!") + self.assertEqual(readback.collect(), df.collect()) + finally: + shutil.rmtree(tpath) + def test_multiline_csv(self): ages_newlines = self.spark.read.csv( "python/test_support/sql/ages_newlines.csv", multiLine=True) @@ -567,7 +737,6 @@ def test_read_multiple_orc_file(self): def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType sourceFile = udf(lambda path: path, StringType()) filePath = "python/test_support/sql/people1.json" row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() @@ -575,7 +744,6 @@ def test_udf_with_input_file_name(self): def test_udf_with_input_file_name_for_hadooprdd(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType def filename(path): return path @@ -635,7 +803,6 @@ def test_udf_with_string_return_type(self): def test_udf_shouldnt_accept_noncallable_object(self): from pyspark.sql.functions import UserDefinedFunction - from pyspark.sql.types import StringType non_callable = None self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) @@ -733,6 +900,22 @@ def __call__(self, x): 
self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + def test_stopiteration_in_udf(self): + # test for SPARK-23754 + from pyspark.sql.functions import udf + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + with self.assertRaises(Py4JJavaError) as cm: + self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() + + self.assertIn( + "Caught StopIteration thrown from user's code; failing the task", + cm.exception.java_exception.toString() + ) + def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column @@ -861,6 +1044,15 @@ def test_infer_schema(self): result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) + def test_infer_schema_not_enough_names(self): + df = self.spark.createDataFrame([["a", "b"]], ["col1"]) + self.assertEqual(df.columns, ['col1', '_2']) + + def test_infer_schema_fails(self): + with self.assertRaisesRegexp(TypeError, 'field a'): + self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), + schema=["a", "b"], samplingRatio=0.99) + def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), @@ -881,6 +1073,10 @@ def test_infer_nested_schema(self): df = self.spark.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) + def test_create_dataframe_from_dict_respects_schema(self): + df = self.spark.createDataFrame([{'a': 1}], ["b"]) + self.assertEqual(df.columns, ['b']) + def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] df = self.spark.createDataFrame(data) @@ -1016,6 +1212,14 @@ def myudf(x): rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) + def test_nonparam_udf_with_aggregate(self): + import pyspark.sql.functions as f + + df = self.spark.createDataFrame([(1, 2), (1, 2)]) + f_udf = f.udf(lambda: "const_str") + rows = df.distinct().withColumn("a", f_udf()).collect() + self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) @@ -1105,6 +1309,17 @@ def test_union_with_udt(self): ] ) + def test_cast_to_string_with_udt(self): + from pyspark.sql.tests import ExamplePointUDT, ExamplePoint + from pyspark.sql.functions import col + row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0)) + schema = StructType([StructField("point", ExamplePointUDT(), False), + StructField("pypoint", PythonOnlyUDT(), False)]) + df = self.spark.createDataFrame([row], schema) + + result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head() + self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')) + def test_column_operators(self): ci = self.df.key cs = self.df.value @@ -1299,7 +1514,6 @@ def test_between_function(self): df.filter(df.a.between(df.b, df.c)).collect()) def test_struct_type(self): - from pyspark.sql.types import StructType, StringType, StructField struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) @@ -1368,7 +1582,6 @@ def test_parse_datatype_string(self): _parse_datatype_string("a INT, c DOUBLE")) def test_metadata_null(self): - from 
pyspark.sql.types import StructType, StringType, StructField schema = StructType([StructField("f1", StringType(), True, None), StructField("f2", StringType(), True, {'a': None})]) rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) @@ -1456,6 +1669,12 @@ def test_stream_trigger(self): except ValueError: pass + # Should not take multiple args + try: + df.writeStream.trigger(processingTime='5 seconds', continuous='1 second') + except ValueError: + pass + # Should take only keyword args try: df.writeStream.trigger('5 seconds') @@ -1737,6 +1956,92 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_merge_type(self): + self.assertEqual(_merge_type(LongType(), NullType()), LongType()) + self.assertEqual(_merge_type(NullType(), LongType()), LongType()) + + self.assertEqual(_merge_type(LongType(), LongType()), LongType()) + + self.assertEqual(_merge_type( + ArrayType(LongType()), + ArrayType(LongType()) + ), ArrayType(LongType())) + with self.assertRaisesRegexp(TypeError, 'element in array'): + _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) + + self.assertEqual(_merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), LongType()) + ), MapType(StringType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'key of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(DoubleType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'value of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), DoubleType())) + + self.assertEqual(_merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", LongType()), StructField("f2", StringType())]) + ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'field f1'): + _merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]) + ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))])) + with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'): + _merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", StringType())]))])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) + ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'element in array field f1'): + _merge_type( + StructType([ + StructField("f1", ArrayType(LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", ArrayType(DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]) + ), StructType([ + StructField("f1", MapType(StringType(), LongType())), + 
StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'value of map field f1'): + _merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) + ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) + with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'): + _merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) + ) + def test_filter_with_datetime(self): time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) date = time.date() @@ -1938,6 +2243,34 @@ def test_expr(self): result = df.select(functions.expr("length(a)")).collect()[0].asDict() self.assertEqual(13, result["length(a)"]) + def test_repartitionByRange_dataframe(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + df1 = self.spark.createDataFrame( + [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema) + df2 = self.spark.createDataFrame( + [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema) + + # test repartitionByRange(numPartitions, *cols) + df3 = df1.repartitionByRange(2, "name", "age") + self.assertEqual(df3.rdd.getNumPartitions(), 2) + self.assertEqual(df3.rdd.first(), df2.rdd.first()) + self.assertEqual(df3.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(numPartitions, *cols) + df4 = df1.repartitionByRange(3, "name", "age") + self.assertEqual(df4.rdd.getNumPartitions(), 3) + self.assertEqual(df4.rdd.first(), df2.rdd.first()) + self.assertEqual(df4.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(*cols) + df5 = df1.repartitionByRange("name", "age") + self.assertEqual(df5.rdd.first(), df2.rdd.first()) + self.assertEqual(df5.rdd.take(3), df2.rdd.take(3)) + def test_replace(self): schema = StructType([ StructField("name", StringType(), True), @@ -2033,11 +2366,6 @@ def test_replace(self): .replace(False, True).first()) self.assertTupleEqual(row, (True, True)) - # replace list while value is not given (default to None) - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() - self.assertTupleEqual(row, (None, 10, 80.0)) - # replace string with None and then drop None rows row = self.spark.createDataFrame( [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna() @@ -2073,6 +2401,12 @@ def test_replace(self): self.spark.createDataFrame( [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first() + with self.assertRaisesRegexp( + TypeError, + 'value argument is required when to_replace is not a dictionary.'): + self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() + def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) @@ -2198,17 +2532,13 @@ def test_join_without_on(self): df1 = self.spark.range(1).toDF("a") df2 = self.spark.range(1).toDF("b") - try: - 
self.spark.conf.set("spark.sql.crossJoin.enabled", "false") + with self.sql_conf({"spark.sql.crossJoin.enabled": False}): self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) - self.spark.conf.set("spark.sql.crossJoin.enabled", "true") + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): actual = df1.join(df2, how="inner").collect() expected = [Row(a=0, b=0)] self.assertEqual(actual, expected) - finally: - # We should unset this. Otherwise, other tests are affected. - self.spark.conf.unset("spark.sql.crossJoin.enabled") # Regression test for invalid join methods when on is None, Spark-14761 def test_invalid_join_method(self): @@ -2241,6 +2571,17 @@ def test_conf(self): spark.conf.unset("bogo") self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") + self.assertEqual(spark.conf.get("hyukjin", None), None) + + # This returns 'STATIC' because it's the default value of + # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in + # `spark.conf.get` is unset. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC") + + # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but + # `defaultValue` in `spark.conf.get` is set to None. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) + def test_current_database(self): spark = self.spark spark.catalog._reset() @@ -2582,7 +2923,6 @@ def count_bucketed_cols(names, table="pyspark_bucket"): def _to_pandas(self): from datetime import datetime, date - import numpy as np schema = StructType().add("a", IntegerType()).add("b", StringType())\ .add("c", BooleanType()).add("d", FloatType())\ .add("dt", DateType()).add("ts", TimestampType()) @@ -2595,7 +2935,7 @@ def _to_pandas(self): df = self.spark.createDataFrame(data, schema) return df.toPandas() - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_to_pandas(self): import numpy as np pdf = self._to_pandas() @@ -2604,16 +2944,16 @@ def test_to_pandas(self): self.assertEquals(types[1], np.object) self.assertEquals(types[2], np.bool) self.assertEquals(types[3], np.float32) - self.assertEquals(types[4], 'datetime64[ns]') + self.assertEquals(types[4], np.object) # datetime.date self.assertEquals(types[5], 'datetime64[ns]') - @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") - def test_to_pandas_old(self): + @unittest.skipIf(_have_pandas, "Required Pandas was found.") + def test_to_pandas_required_pandas_not_found(self): with QuietTest(self.sc): with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): self._to_pandas() - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_to_pandas_avoid_astype(self): import numpy as np schema = StructType().add("a", IntegerType()).add("b", StringType())\ @@ -2631,7 +2971,7 @@ def test_create_dataframe_from_array_of_long(self): df = self.spark.createDataFrame(data) self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_create_dataframe_from_pandas_with_timestamp(self): import pandas as pd from datetime import datetime @@ -2646,19 +2986,99 @@ def test_create_dataframe_from_pandas_with_timestamp(self): self.assertTrue(isinstance(df.schema['ts'].dataType, 
TimestampType)) self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) - @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") - def test_create_dataframe_from_old_pandas(self): - import pandas as pd - from datetime import datetime - pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], - "d": [pd.Timestamp.now().date()]}) + @unittest.skipIf(_have_pandas, "Required Pandas was found.") + def test_create_dataframe_required_pandas_not_found(self): with QuietTest(self.sc): - with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): + with self.assertRaisesRegexp( + ImportError, + "(Pandas >= .* must be installed|No module named '?pandas'?)"): + import pandas as pd + from datetime import datetime + pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], + "d": [pd.Timestamp.now().date()]}) self.spark.createDataFrame(pdf) + # Regression test for SPARK-23360 + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) + def test_create_dateframe_from_pandas_with_dst(self): + import pandas as pd + from datetime import datetime + + pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]}) + + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + + orig_env_tz = os.environ.get('TZ', None) + try: + tz = 'America/Los_Angeles' + os.environ['TZ'] = tz + time.tzset() + with self.sql_conf({'spark.sql.session.timeZone': tz}): + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + finally: + del os.environ['TZ'] + if orig_env_tz is not None: + os.environ['TZ'] = orig_env_tz + time.tzset() + + def test_sort_with_nulls_order(self): + from pyspark.sql import functions + + df = self.spark.createDataFrame( + [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')]) + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(), + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) + + def test_json_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x)) + schema = self.spark.read.option('inferSchema', True) \ + .option('samplingRatio', 0.5) \ + .json(rdd).schema + self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + + def test_csv_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '0.1' if x == 1 else str(x)) + schema = self.spark.read.option('inferSchema', True)\ + .csv(rdd, samplingRatio=0.5).schema + self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + class HiveSparkSubmitTests(SparkSubmitTests): + @classmethod + def setUpClass(cls): + # get a SparkContext to check for availability of Hive + sc = SparkContext('local[4]', cls.__name__) + cls.hive_available = True + try: + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.hive_available = False + except TypeError: + cls.hive_available = False + finally: + # we don't need this SparkContext for the test + sc.stop() + + def 
setUp(self): + super(HiveSparkSubmitTests, self).setUp() + if not self.hive_available: + self.skipTest("Hive is not available.") + def test_hivecontext(self): # This test checks that HiveContext is using Hive metastore (SPARK-16224). # It sets a metastore url and checks if there is a derby dir created by @@ -2688,8 +3108,8 @@ def test_hivecontext(self): |print(hive_context.sql("show databases").collect()) """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", - "--driver-class-path", hive_site_dir, script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", + "--driver-class-path", hive_site_dir, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2713,6 +3133,95 @@ def test_sparksession_with_stopped_sparkcontext(self): sc.stop() +class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): + # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is + # static and immutable. This can't be set or unset, for example, via `spark.conf`. + + @classmethod + def setUpClass(cls): + import glob + from pyspark.find_spark_home import _find_spark_home + + SPARK_HOME = _find_spark_home() + filename_pattern = ( + "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" + "TestQueryExecutionListener.class") + cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))) + + if cls.has_listener: + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + def setUp(self): + if not self.has_listener: + raise self.skipTest( + "'org.apache.spark.sql.TestQueryExecutionListener' is not " + "available. Will skip the related tests.") + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "spark"): + cls.spark.stop() + + def tearDown(self): + self.spark._jvm.OnSuccessCall.clear() + + def test_query_execution_listener_on_collect(self): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be called before 'collect'") + self.spark.sql("SELECT * FROM range(1)").collect() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'collect'") + + @unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) + def test_query_execution_listener_on_collect_with_arrow(self): + with self.sql_conf({"spark.sql.execution.arrow.enabled": True}): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be " + "called before 'toPandas'") + self.spark.sql("SELECT * FROM range(1)").toPandas() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'toPandas'") + + +class SparkSessionTests(PySparkTestCase): + + # This test is separate because it's closely related with session's start and stop. + # See SPARK-23228. 
+ def test_set_jvm_default_session(self): + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + finally: + spark.stop() + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty()) + + def test_jvm_default_session_already_set(self): + # Here, we assume there is the default session already set in JVM. + jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc()) + self.sc._jvm.SparkSession.setDefaultSession(jsession) + + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + # The session should be the same as the existing one. + self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get())) + finally: + spark.stop() + + class UDFInitializationTests(unittest.TestCase): def tearDown(self): if SparkSession._instantiatedSession is not None: @@ -2742,18 +3251,22 @@ class HiveContextSQLTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + cls.hive_available = True try:
+ cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), @@ -3172,11 +3689,11 @@ def setUpClass(cls): StructField("7_date_t", DateType(), True), StructField("8_timestamp_t", TimestampType(), True)]) cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), - datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), - datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), - datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] @classmethod def tearDownClass(cls): @@ -3186,12 +3703,6 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() - def assertFramesEqual(self, df_with_arrow, df_without): - msg = ("DataFrame from Arrow is not equal" + - ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + - ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) - self.assertTrue(df_without.equals(df_with_arrow), msg=msg) - def create_pandas_data_frame(self): import pandas as pd import numpy as np @@ -3203,11 +3714,28 @@ def create_pandas_data_frame(self): data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) return pd.DataFrame(data=data_dict) - def test_unsupported_datatype(self): + def test_toPandas_fallback_enabled(self): + import pandas as pd + + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) + df = self.spark.createDataFrame([({u'a': 1},)], schema=schema) + with QuietTest(self.sc): + with warnings.catch_warnings(record=True) as warns: + pdf = df.toPandas() + # Catch and check the last UserWarning. 
+ user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempting non-optimization" in _exception_message(user_warns[-1])) + self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) + + def test_toPandas_fallback_disabled(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + with self.assertRaisesRegexp(Exception, 'Unsupported type'): df.toPandas() def test_null_conversion(self): @@ -3218,33 +3746,35 @@ def test_null_conversion(self): self.assertTrue(all([c == 1 for c in null_counts])) def _toPandas_arrow_toggle(self, df): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): pdf = df.toPandas() - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + pdf_arrow = df.toPandas() + return pdf, pdf_arrow def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow, pdf) + expected = self.create_pandas_data_frame() + self.assertPandasEqual(expected, pdf) + self.assertPandasEqual(expected, pdf_arrow) def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_la, pdf_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf_arrow_la, pdf_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_ny, pdf_ny) + self.assertPandasEqual(pdf_arrow_ny, pdf_ny) self.assertFalse(pdf_ny.equals(pdf_la)) @@ -3254,15 +3784,13 @@ def test_toPandas_respect_session_timezone(self): if isinstance(field.dataType, TimestampType): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) - self.assertFramesEqual(pdf_ny, pdf_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + self.assertPandasEqual(pdf_ny, pdf_la_corrected) def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_filtered_frame(self): df = self.spark.range(3).toDF("i") @@ -3272,12 +3800,11 @@ def test_filtered_frame(self): self.assertTrue(pdf.empty) def _createDataFrame_toggle(self, pdf, schema=None): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with 
self.sql_conf({"spark.sql.execution.arrow.enabled": False}): df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + df_arrow = self.spark.createDataFrame(pdf, schema=schema) + return df_no_arrow, df_arrow def test_createDataFrame_toggle(self): @@ -3288,18 +3815,18 @@ def test_createDataFrame_toggle(self): def test_createDataFrame_respect_session_timezone(self): from datetime import timedelta pdf = self.create_pandas_data_frame() - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) - result_la = df_no_arrow_la.collect() - result_arrow_la = df_arrow_la.collect() - self.assertEqual(result_la, result_arrow_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) + result_la = df_no_arrow_la.collect() + result_arrow_la = df_arrow_la.collect() + self.assertEqual(result_la, result_arrow_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) result_ny = df_no_arrow_ny.collect() result_arrow_ny = df_arrow_ny.collect() @@ -3312,21 +3839,19 @@ def test_createDataFrame_respect_session_timezone(self): for k, v in row.asDict().items()}) for row in result_la] self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_createDataFrame_with_schema(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(pdf, schema=self.schema) self.assertEquals(self.schema, df.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() wrong_schema = StructType(list(reversed(self.schema))) with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"): + with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): @@ -3351,7 +3876,7 @@ def test_createDataFrame_column_name_encoding(self): def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"): + with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): @@ -3397,8 +3922,58 @@ def test_toPandas_with_array_type(self): self.assertTrue(expected[r][e] == result_arrow[r][e] and result[r][e] == result_arrow[r][e]) + def test_createDataFrame_with_int_col_names(self): + import numpy as np + import pandas as pd + pdf = pd.DataFrame(np.random.rand(4, 2)) + df, df_arrow = self._createDataFrame_toggle(pdf) + pdf_col_names = [str(c) for c in pdf.columns] + 
self.assertEqual(pdf_col_names, df.columns) + self.assertEqual(pdf_col_names, df_arrow.columns) + + def test_createDataFrame_fallback_enabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + with warnings.catch_warnings(record=True) as warns: + df = self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Catch and check the last UserWarning. + user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempting non-optimization" in _exception_message(user_warns[-1])) + self.assertEqual(df.collect(), [Row(a={u'a': 1})]) + + def test_createDataFrame_fallback_disabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, 'Unsupported type'): + self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + import pandas as pd + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + pdf = pd.DataFrame({'time': dt}) -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") + df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + df_from_pandas = self.spark.createDataFrame(pdf) + + self.assertPandasEqual(pdf, df_from_python.toPandas()) + self.assertPandasEqual(pdf, df_from_pandas.toPandas()) + + +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class PandasUDFTests(ReusedSQLTestCase): def test_pandas_udf_basic(self): from pyspark.rdd import PythonEvalType @@ -3406,34 +3981,34 @@ def test_pandas_udf_basic(self): udf = pandas_udf(lambda x: x, DoubleType()) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]), - PandasUDFType.GROUP_MAP) + PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUP_MAP) + udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) udf = pandas_udf(lambda x: x, 'v double', - functionType=PandasUDFType.GROUP_MAP) + functionType=PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, 
StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) udf = pandas_udf(lambda x: x, returnType='v double', - functionType=PandasUDFType.GROUP_MAP) + functionType=PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_pandas_udf_decorator(self): from pyspark.rdd import PythonEvalType @@ -3444,45 +4019,45 @@ def test_pandas_udf_decorator(self): def foo(x): return x self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) @pandas_udf(returnType=DoubleType()) def foo(x): return x self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) schema = StructType([StructField("v", DoubleType())]) - @pandas_udf(schema, PandasUDFType.GROUP_MAP) + @pandas_udf(schema, PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf('v double', PandasUDFType.GROUP_MAP) + @pandas_udf('v double', PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf(schema, functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) + @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR) def foo(x): return x - self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_udf_wrong_arg(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -3509,21 +4084,23 @@ def zero_with_type(): return 1 with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): - @pandas_udf(returnType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType=PandasUDFType.GROUPED_MAP) def foo(df): return df - with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): - @pandas_udf(returnType='double', functionType=PandasUDFType.GROUP_MAP) + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): + @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP) def foo(df): return df 
with self.assertRaisesRegexp(ValueError, 'Invalid function'): - @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUP_MAP) - def foo(k, v): + @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) + def foo(k, v, w): return k -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class VectorizedUDFTests(ReusedSQLTestCase): +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) +class ScalarPandasUDFTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): @@ -3546,8 +4123,20 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() + @property + def nondeterministic_vectorized_udf(self): + from pyspark.sql.functions import pandas_udf + + @pandas_udf('double') + def random_udf(v): + import pandas as pd + import numpy as np + return pd.Series(np.random.random(len(v))) + random_udf = random_udf.asNondeterministic() + return random_udf + def test_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, array df = self.spark.range(10).select( col('id').cast('string').alias('str'), col('id').cast('int').alias('int'), @@ -3555,7 +4144,8 @@ def test_vectorized_udf_basic(self): col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), col('id').cast('decimal').alias('decimal'), - col('id').cast('boolean').alias('bool')) + col('id').cast('boolean').alias('bool'), + array(col('id')).alias('array_long')) f = lambda x: x str_f = pandas_udf(f, StringType()) int_f = pandas_udf(f, IntegerType()) @@ -3564,12 +4154,28 @@ def test_vectorized_udf_basic(self): double_f = pandas_udf(f, DoubleType()) decimal_f = pandas_udf(f, DecimalType()) bool_f = pandas_udf(f, BooleanType()) + array_long_f = pandas_udf(f, ArrayType(LongType())) res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), double_f(col('double')), decimal_f('decimal'), - bool_f(col('bool'))) + bool_f(col('bool')), array_long_f('array_long')) self.assertEquals(df.collect(), res.collect()) + def test_register_nondeterministic_vectorized_udf_basic(self): + from pyspark.sql.functions import pandas_udf + from pyspark.rdd import PythonEvalType + import random + random_pandas_udf = pandas_udf( + lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() + self.assertEqual(random_pandas_udf.deterministic, False) + self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + nondeterministic_pandas_udf = self.spark.catalog.registerFunction( + "randomPandasUDF", random_pandas_udf) + self.assertEqual(nondeterministic_pandas_udf.deterministic, False) + self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() + self.assertEqual(row[0], 7) + def test_vectorized_udf_null_boolean(self): from pyspark.sql.functions import pandas_udf, col data = [(True,), (True,), (None,), (False,)] @@ -3652,6 +4258,15 @@ def test_vectorized_udf_null_string(self): res = df.select(str_f(col('str'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_string_in_udf(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + df = self.spark.range(10) + str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType()) + actual = df.select(str_f(col('id'))) + expected = df.select(col('id').cast('string')) 
+ self.assertEquals(expected.collect(), actual.collect()) + def test_vectorized_udf_datatype_string(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3748,10 +4363,11 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) - f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): - df.select(f(col('id'))).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) def test_vectorized_udf_return_scalar(self): from pyspark.sql.functions import pandas_udf, col @@ -3786,26 +4402,55 @@ def test_vectorized_udf_varargs(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col - schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(None,)], schema=schema) - f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + from pyspark.sql.functions import pandas_udf + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.select(f(col('map'))).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): + pandas_udf(lambda x: x, BinaryType()) - def test_vectorized_udf_null_date(self): + def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col from datetime import date - schema = StructType().add("date", DateType()) - data = [(date(1969, 1, 1),), - (date(2012, 2, 2),), - (None,), - (date(2100, 4, 4),)] + schema = StructType().add("idx", LongType()).add("date", DateType()) + data = [(0, date(1969, 1, 1),), + (1, date(2012, 2, 2),), + (2, None,), + (3, date(2100, 4, 4),)] df = self.spark.createDataFrame(data, schema=schema) - date_f = pandas_udf(lambda t: t, returnType=DateType()) - res = df.select(date_f(col("date"))) - self.assertEquals(df.collect(), res.collect()) + + date_copy = pandas_udf(lambda t: t, returnType=DateType()) + df = df.withColumn("date_copy", date_copy(col("date"))) + + @pandas_udf(returnType=StringType()) + def check_data(idx, date, date_copy): + import pandas as pd + msgs = [] + is_equal = date.isnull() + for i in range(len(idx)): + if (is_equal[i] and data[idx[i]][1] is None) or \ + date[i] == data[idx[i]][1]: + msgs.append(None) + else: + msgs.append( + "date values are not equal (date='%s': data[%d][1]='%s')" + % (date[i], idx[i], data[idx[i]][1])) + return pd.Series(msgs) + + result = df.withColumn("check_data", + check_data(col("idx"), col("date"), col("date_copy"))).collect() + + self.assertEquals(len(data), len(result)) + for i in range(len(result)): + self.assertEquals(data[i][1], result[i][1]) # "date" col + self.assertEquals(data[i][1], result[i][2]) # "date_copy" col + self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_timestamps(self): from pyspark.sql.functions import pandas_udf, col @@ -3846,6 +4491,7 @@ def check_data(idx, timestamp, timestamp_copy): 
self.assertEquals(len(data), len(result)) for i in range(len(result)): self.assertEquals(data[i][1], result[i][1]) # "timestamp" col + self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_return_timestamp_tz(self): @@ -3869,9 +4515,7 @@ def gen_timestamps(id): def test_vectorized_udf_check_config(self): from pyspark.sql.functions import pandas_udf, col import pandas as pd - orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None) - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3) - try: + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): df = self.spark.range(10, numPartitions=1) @pandas_udf(returnType=LongType()) @@ -3881,11 +4525,6 @@ def check_records_per_batch(x): result = df.select(check_records_per_batch(col("id"))).collect() for (r,) in result: self.assertTrue(r <= 3) - finally: - if orig_value is None: - self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") - else: - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value) def test_vectorized_udf_timestamps_respect_session_timezone(self): from pyspark.sql.functions import pandas_udf, col @@ -3904,40 +4543,107 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): internal_value = pandas_udf( lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType()) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ - .withColumn("internal_value", internal_value(col("timestamp"))) - result_la = df_la.select(col("idx"), col("internal_value")).collect() - # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - diff = 3 * 60 * 60 * 1000 * 1000 * 1000 - result_la_corrected = \ - df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ + .withColumn("internal_value", internal_value(col("timestamp"))) + result_la = df_la.select(col("idx"), col("internal_value")).collect() + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + diff = 3 * 60 * 60 * 1000 * 1000 * 1000 + result_la_corrected = \ + df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ .withColumn("internal_value", internal_value(col("timestamp"))) result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect() self.assertNotEqual(result_ny, result_la) self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_nondeterministic_vectorized_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions 
import udf, pandas_udf, col -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyApplyTests(ReusedSQLTestCase): + @pandas_udf('double') + def plus_ten(v): + return v + 10 + random_udf = self.nondeterministic_vectorized_udf - def assertFramesEqual(self, expected, result): - msg = ("DataFrames are not equal: " + - ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + - ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) - self.assertTrue(expected.equals(result), msg=msg) + df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) + result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() + + self.assertEqual(random_udf.deterministic, False) + self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) + + def test_nondeterministic_vectorized_udf_in_aggregate(self): + from pyspark.sql.functions import pandas_udf, sum + + df = self.spark.range(10) + random_udf = self.nondeterministic_vectorized_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.groupby(df.id).agg(sum(random_udf(df.id))).collect() + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.agg(sum(random_udf(df.id))).collect() + + def test_register_vectorized_udf_basic(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, col, expr + df = self.spark.range(10).select( + col('id').cast('int').alias('a'), + col('id').cast('int').alias('b')) + original_add = pandas_udf(lambda x, y: x + y, IntegerType()) + self.assertEqual(original_add.deterministic, True) + self.assertEqual(original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + new_add = self.spark.catalog.registerFunction("add1", original_add) + res1 = df.select(new_add(col('a'), col('b'))) + res2 = self.spark.sql( + "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t") + expected = df.select(expr('a + b')) + self.assertEquals(expected.collect(), res1.collect()) + self.assertEquals(expected.collect(), res2.collect()) + + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + from pyspark.sql.functions import pandas_udf + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + foo_udf = pandas_udf(lambda x: x, 'timestamp') + result = df.withColumn('time', foo_udf(df.time)) + self.assertEquals(df.collect(), result.collect()) + + @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") + def test_type_annotation(self): + from pyspark.sql.functions import pandas_udf + # Regression test to check if type hints can be used. See SPARK-23569. + # Note that it throws an error during compilation in lower Python versions if 'exec' + # is not used. Also, note that we explicitly use another dictionary to avoid modifications + # in the current 'locals()'. + # + # Hyukjin: I think it's an ugly way to test issues about syntax specific in + # higher versions of Python, which we shouldn't encourage. This was the last resort + # I could come up with at that time. 
+ _locals = {} + exec( + "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col", + _locals) + df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) + self.assertEqual(df.first()[0], 0) + + +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) +class GroupedMapPandasUDFTests(ReusedSQLTestCase): @property def data(self): @@ -3946,23 +4652,58 @@ def data(self): .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))).drop('vs') - def test_simple(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data + def test_supported_types(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + df = self.data.withColumn("arr", array(col("id"))) - foo_udf = pandas_udf( + # Different forms of group map pandas UDF, results of these are the same + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType())), + StructField('v1', DoubleType()), + StructField('v2', LongType())]) + + udf1 = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('v1', DoubleType()), - StructField('v2', LongType())]), - PandasUDFType.GROUP_MAP + output_schema, + PandasUDFType.GROUPED_MAP ) - result = df.groupby('id').apply(foo_udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + udf2 = pandas_udf( + lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + udf3 = pandas_udf( + lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result1 = df.groupby('id').apply(udf1).sort('id').toPandas() + expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True) + + result2 = df.groupby('id').apply(udf2).sort('id').toPandas() + expected2 = expected1 + + result3 = df.groupby('id').apply(udf3).sort('id').toPandas() + expected3 = expected1 + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) + + def test_register_grouped_map_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or ' + 'SQL_SCALAR_PANDAS_UDF'): + self.spark.catalog.registerFunction("foo_udf", foo_udf) def test_decorator(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -3970,14 +4711,14 @@ def test_decorator(self): @pandas_udf( 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def foo(pdf): return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_coerce(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -3986,13 +4727,13 @@ def test_coerce(self): foo = pandas_udf( lambda pdf: pdf, 'id long, v double', - 
PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) expected = expected.assign(v=expected.v.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_complex_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4000,7 +4741,7 @@ def test_complex_groupby(self): @pandas_udf( 'id long, v int, norm double', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def normalize(pdf): v = pdf.v @@ -4011,7 +4752,7 @@ def normalize(pdf): expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_empty_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4019,7 +4760,7 @@ def test_empty_groupby(self): @pandas_udf( 'id long, v int, norm double', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def normalize(pdf): v = pdf.v @@ -4030,7 +4771,7 @@ def normalize(pdf): expected = normalize.func(pdf) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_datatype_string(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4039,26 +4780,24 @@ def test_datatype_string(self): foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data - - foo = pandas_udf( - lambda pdf: pdf, - 'id long, v map', - PandasUDFType.GROUP_MAP - ) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): - df.groupby('id').apply(foo).sort('id').toPandas() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*MapType'): + pandas_udf( + lambda pdf: pdf, + 'id long, v map', + PandasUDFType.GROUPED_MAP) def test_wrong_args(self): from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType @@ -4077,28 +4816,565 @@ def test_wrong_args(self): df.groupby('id').apply( pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) with self.assertRaisesRegexp(ValueError, 'Invalid udf'): + df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'): df.groupby('id').apply( - pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUP_MAP'): - df.groupby('id').apply( - pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), - PandasUDFType.SCALAR)) + pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) def test_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType + from pyspark.sql.functions 
import pandas_udf, PandasUDFType schema = StructType( [StructField("id", LongType(), True), StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.groupby('id').apply(f).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*MapType'): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) + + schema = StructType( + [StructField("id", LongType(), True), + StructField("arr_ts", ArrayType(TimestampType()), True)]) + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) + + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP) + result = df.groupby('time').apply(foo_udf).sort('time') + self.assertPandasEqual(df.toPandas(), result.toPandas()) + + def test_udf_with_key(self): + from pyspark.sql.functions import pandas_udf, col, PandasUDFType + df = self.data + pdf = df.toPandas() + + def foo1(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + + return pdf.assign(v1=key[0], + v2=pdf.v * key[0], + v3=pdf.v * pdf.id, + v4=pdf.v * pdf.id.mean()) + + def foo2(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + assert type(key[1]) == np.int32 + + return pdf.assign(v1=key[0], + v2=key[1], + v3=pdf.v * key[0], + v4=pdf.v + key[1]) + + def foo3(key, pdf): + assert type(key) == tuple + assert len(key) == 0 + return pdf.assign(v1=pdf.v * pdf.id) + + # v2 is int because numpy.int64 * pd.Series results in pd.Series + # v3 is long because pd.Series * pd.Series results in pd.Series + udf1 = pandas_udf( + foo1, + 'id long, v int, v1 long, v2 int, v3 long, v4 double', + PandasUDFType.GROUPED_MAP) + + udf2 = pandas_udf( + foo2, + 'id long, v int, v1 long, v2 int, v3 int, v4 int', + PandasUDFType.GROUPED_MAP) + + udf3 = pandas_udf( + foo3, + 'id long, v int, v1 long', + PandasUDFType.GROUPED_MAP) + + # Test groupby column + result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas() + expected1 = pdf.groupby('id')\ + .apply(lambda x: udf1.func((x.id.iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected1, result1) + + # Test groupby expression + result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas() + expected2 = pdf.groupby(pdf.id % 2)\ + .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected2, result2) + + # Test complex groupby + result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas() + expected3 = pdf.groupby([pdf.id, pdf.v % 2])\ + .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + 
self.assertPandasEqual(expected3, result3) + + # Test empty groupby + result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas() + expected4 = udf3.func((), pdf) + self.assertPandasEqual(expected4, result4) + + +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) +class GroupedAggPandasUDFTests(ReusedSQLTestCase): + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + + @udf('double') + def plus_one(v): + assert isinstance(v, (int, float)) + return v + 1 + return plus_one + + @property + def pandas_scalar_plus_two(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.SCALAR) + def plus_two(v): + assert isinstance(v, pd.Series) + return v + 2 + return plus_two + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_sum_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def sum(v): + return v.sum() + return sum + + @property + def pandas_agg_weighted_mean_udf(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def weighted_mean(v, w): + return np.average(v, weights=w) + return weighted_mean + + def test_manual(self): + from pyspark.sql.functions import pandas_udf, array + + df = self.data + sum_udf = self.pandas_agg_sum_udf + mean_udf = self.pandas_agg_mean_udf + mean_arr_udf = pandas_udf( + self.pandas_agg_mean_udf.func, + ArrayType(self.pandas_agg_mean_udf.returnType), + self.pandas_agg_mean_udf.evalType) + + result1 = df.groupby('id').agg( + sum_udf(df.v), + mean_udf(df.v), + mean_arr_udf(array(df.v))).sort('id') + expected1 = self.spark.createDataFrame( + [[0, 245.0, 24.5, [24.5]], + [1, 255.0, 25.5, [25.5]], + [2, 265.0, 26.5, [26.5]], + [3, 275.0, 27.5, [27.5]], + [4, 285.0, 28.5, [28.5]], + [5, 295.0, 29.5, [29.5]], + [6, 305.0, 30.5, [30.5]], + [7, 315.0, 31.5, [31.5]], + [8, 325.0, 32.5, [32.5]], + [9, 335.0, 33.5, [33.5]]], + ['id', 'sum(v)', 'avg(v)', 'avg(array(v))']) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_basic(self): + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + # Groupby one column and aggregate one UDF with literal + result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') + expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id') + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + # Groupby one expression and aggregate one UDF with literal + result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\ + .sort(df.id + 1) + expected2 = df.groupby((col('id') + 1))\ + .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + # Groupby one column and aggregate one UDF 
without literal + result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id') + expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id') + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + # Groupby one expression and aggregate one UDF without literal + result4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(weighted_mean_udf(df.v, df.w))\ + .sort('id') + expected4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(mean(df.v).alias('weighted_mean(v, w)'))\ + .sort('id') + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + def test_unsupported_types(self): + from pyspark.sql.types import DoubleType, MapType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + pandas_udf( + lambda x: x, + ArrayType(ArrayType(TimestampType())), + PandasUDFType.GROUPED_AGG) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG) + def mean_and_std_udf(v): + return v.mean(), v.std() + + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) + def mean_and_std_udf(v): + return {v.mean(): v.std()} + + def test_alias(self): + from pyspark.sql.functions import mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + + result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')) + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + """ + Test mixing group aggregate pandas UDF with sql expression. + """ + from pyspark.sql.functions import sum, mean + + df = self.data + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF with sql expression + result1 = (df.groupby('id') + .agg(sum_udf(df.v) + 1) + .sort('id')) + expected1 = (df.groupby('id') + .agg(sum(df.v) + 1) + .sort('id')) + + # Mix group aggregate pandas UDF with sql expression (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(df.v + 1)) + .sort('id')) + + expected2 = (df.groupby('id') + .agg(sum(df.v + 1)) + .sort('id')) + + # Wrap group aggregate pandas UDF with two sql expressions + result3 = (df.groupby('id') + .agg(sum_udf(df.v + 1) + 2) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(df.v + 1) + 2) + .sort('id')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + def test_mixed_udfs(self): + """ + Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF. 
+ """ + from pyspark.sql.functions import sum, mean + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF and python UDF + result1 = (df.groupby('id') + .agg(plus_one(sum_udf(df.v))) + .sort('id')) + expected1 = (df.groupby('id') + .agg(plus_one(sum(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and python UDF (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(plus_one(df.v))) + .sort('id')) + expected2 = (df.groupby('id') + .agg(sum(plus_one(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF + result3 = (df.groupby('id') + .agg(sum_udf(plus_two(df.v))) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(plus_two(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped) + result4 = (df.groupby('id') + .agg(plus_two(sum_udf(df.v))) + .sort('id')) + expected4 = (df.groupby('id') + .agg(plus_two(sum(df.v))) + .sort('id')) + + # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby + result5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum_udf(plus_one(df.v)))) + .sort('plus_one(id)')) + expected5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum(plus_one(df.v)))) + .sort('plus_one(id)')) + + # Wrap group aggregate pandas UDF with two scala pandas UDF and user scala pandas UDF in + # groupby + result6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum_udf(plus_two(df.v)))) + .sort('plus_two(id)')) + expected6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum(plus_two(df.v)))) + .sort('plus_two(id)')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + + def test_multiple_udfs(self): + """ + Test multiple group aggregate pandas UDFs in one agg function. 
+ """ + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + sum_udf = self.pandas_agg_sum_udf + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + result1 = (df.groupBy('id') + .agg(mean_udf(df.v), + sum_udf(df.v), + weighted_mean_udf(df.v, df.w)) + .sort('id') + .toPandas()) + expected1 = (df.groupBy('id') + .agg(mean(df.v), + sum(df.v), + mean(df.v).alias('weighted_mean(v, w)')) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + + def test_complex_groupby(self): + from pyspark.sql.functions import lit, sum + + df = self.data + sum_udf = self.pandas_agg_sum_udf + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + + # groupby one expression + result1 = df.groupby(df.v % 2).agg(sum_udf(df.v)) + expected1 = df.groupby(df.v % 2).agg(sum(df.v)) + + # empty groupby + result2 = df.groupby().agg(sum_udf(df.v)) + expected2 = df.groupby().agg(sum(df.v)) + + # groupby one column and one sql expression + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + + # groupby one python UDF + result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) + expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v)) + + # groupby one scalar pandas UDF + result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)) + expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)) + + # groupby one expression and one python UDF + result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v)) + expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v)) + + # groupby one expression and one scalar pandas UDF + result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)') + expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)') + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + self.assertPandasEqual(expected7.toPandas(), result7.toPandas()) + + def test_complex_expressions(self): + from pyspark.sql.functions import col, sum + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Test complex expressions with sql expression, python UDF and + # group aggregate pandas UDF + result1 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_one(sum_udf(col('v1'))), + sum_udf(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + expected1 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_one(sum(col('v1'))), + sum(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + # Test complex expressions with sql expression, scala pandas UDF and + # group aggregate pandas UDF + result2 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_two(sum_udf(col('v1'))), + sum_udf(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + 
expected2 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_two(sum(col('v1'))), + sum(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + # Test sequential groupby aggregate + result3 = (df.groupby('id') + .agg(sum_udf(df.v).alias('v')) + .groupby('id') + .agg(sum_udf(col('v'))) + .sort('id') + .toPandas()) + + expected3 = (df.groupby('id') + .agg(sum(df.v).alias('v')) + .groupby('id') + .agg(sum(col('v'))) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) + + def test_retain_group_columns(self): + from pyspark.sql.functions import sum, lit, col + with self.sql_conf({"spark.sql.retainGroupColumns": False}): + df = self.data + sum_udf = self.pandas_agg_sum_udf + + result1 = df.groupby(df.id).agg(sum_udf(df.v)) + expected1 = df.groupby(df.id).agg(sum(df.v)) + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_invalid_args(self): + from pyspark.sql.functions import mean + + df = self.data + plus_one = self.python_plus_one + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'nor.*aggregate function'): + df.groupby(df.id).agg(plus_one(df.v)).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'aggregate function.*argument.*aggregate function'): + df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'mixture.*aggregate function.*group aggregate pandas UDF'): + df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 146e673ae975..3cd7a2ef115a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -35,7 +35,6 @@ from pyspark import SparkContext from pyspark.serializers import CloudPickleSerializer -from pyspark.util import _exception_message __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", @@ -290,7 +289,8 @@ def __init__(self, elementType, containsNull=True): >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ - assert isinstance(elementType, DataType), "elementType should be DataType" + assert isinstance(elementType, DataType),\ + "elementType %s should be an instance of %s" % (elementType, DataType) self.elementType = elementType self.containsNull = containsNull @@ -344,8 +344,10 @@ def __init__(self, keyType, valueType, valueContainsNull=True): ... 
== MapType(StringType(), FloatType())) False """ - assert isinstance(keyType, DataType), "keyType should be DataType" - assert isinstance(valueType, DataType), "valueType should be DataType" + assert isinstance(keyType, DataType),\ + "keyType %s should be an instance of %s" % (keyType, DataType) + assert isinstance(valueType, DataType),\ + "valueType %s should be an instance of %s" % (valueType, DataType) self.keyType = keyType self.valueType = valueType self.valueContainsNull = valueContainsNull @@ -403,8 +405,9 @@ def __init__(self, name, dataType, nullable=True, metadata=None): ... == StructField("f2", StringType(), True)) False """ - assert isinstance(dataType, DataType), "dataType should be DataType" - assert isinstance(name, basestring), "field name should be string" + assert isinstance(dataType, DataType),\ + "dataType %s should be an instance of %s" % (dataType, DataType) + assert isinstance(name, basestring), "field name %s should be string" % (name) if not isinstance(name, str): name = name.encode('utf-8') self.name = name @@ -455,9 +458,6 @@ class StructType(DataType): Iterating a :class:`StructType` will iterate its :class:`StructField`\\s. A contained :class:`StructField` can be accessed by name or position. - .. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead - to get a list of field names. - >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct1["f1"] StructField(f1,StringType,true) @@ -755,41 +755,6 @@ def __eq__(self, other): _FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)") -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - -def _ignore_brackets_split(s, separator): - """ - Splits the given string by given separator, but ignore separators inside brackets pairs, e.g. - given "a,b" and separator ",", it will return ["a", "b"], but given "a, d", it will return - ["a", "d"]. - """ - parts = [] - buf = "" - level = 0 - for c in s: - if c in _BRACKETS.keys(): - level += 1 - buf += c - elif c in _BRACKETS.values(): - if level == 0: - raise ValueError("Brackets are not correctly paired: %s" % s) - level -= 1 - buf += c - elif c == separator and level > 0: - buf += c - elif c == separator: - parts.append(buf) - buf = "" - else: - buf += c - - if len(buf) == 0: - raise ValueError("The %s cannot be the last char: %s" % (separator, s)) - parts.append(buf) - return parts - - def _parse_datatype_string(s): """ Parses the given data type string to a :class:`DataType`. 
The data type string format equals @@ -1073,7 +1038,7 @@ def _infer_type(obj): raise TypeError("not supported type: %s" % type(obj)) -def _infer_schema(row): +def _infer_schema(row, names=None): """Infer the schema from dict/namedtuple/object""" if isinstance(row, dict): items = sorted(row.items()) @@ -1084,7 +1049,10 @@ def _infer_schema(row): elif hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) else: - names = ['_%d' % i for i in range(1, len(row) + 1)] + if names is None: + names = ['_%d' % i for i in range(1, len(row) + 1)] + elif len(names) < len(row): + names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1)) items = zip(names, row) elif hasattr(row, "__dict__"): # object @@ -1109,19 +1077,27 @@ def _has_nulltype(dt): return isinstance(dt, NullType) -def _merge_type(a, b): +def _merge_type(a, b, name=None): + if name is None: + new_msg = lambda msg: msg + new_name = lambda n: "field %s" % n + else: + new_msg = lambda msg: "%s: %s" % (name, msg) + new_name = lambda n: "field %s in %s" % (n, name) + if isinstance(a, NullType): return b elif isinstance(b, NullType): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) + raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b)))) # same type if isinstance(a, StructType): nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), + name=new_name(f.name))) for f in a.fields] names = set([f.name for f in fields]) for n in nfs: @@ -1130,11 +1106,12 @@ def _merge_type(a, b): return StructType(fields) elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) + return ArrayType(_merge_type(a.elementType, b.elementType, + name='element in array %s' % name), True) elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), + return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name), + _merge_type(a.valueType, b.valueType, name='value of map %s' % name), True) else: return a @@ -1626,6 +1603,8 @@ def to_arrow_type(dt): # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') elif type(dt) == ArrayType: + if type(dt.elementType) == TimestampType: + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) arrow_type = pa.list_(to_arrow_type(dt.elementType)) else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) @@ -1668,6 +1647,8 @@ def from_arrow_type(at): elif types.is_timestamp(at): spark_type = TimestampType() elif types.is_list(at): + if types.is_timestamp(at.value_type): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) spark_type = ArrayType(from_arrow_type(at.value_type)) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) @@ -1682,6 +1663,71 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) +def _check_series_convert_date(series, data_type): + """ + Cast the series to datetime.date if it's a date type, otherwise returns the original series. 
+ + :param series: pandas.Series + :param data_type: a Spark data type for the series + """ + if type(data_type) == DateType: + return series.dt.date + else: + return series + + +def _check_dataframe_convert_date(pdf, schema): + """ Correct date type value to use datetime.date. + + Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should + use datetime.date to match the behavior with when Arrow optimization is disabled. + + :param pdf: pandas.DataFrame + :param schema: a Spark schema of the pandas.DataFrame + """ + for field in schema: + pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType) + return pdf + + +def _get_local_timezone(): + """ Get local timezone using pytz with environment variable, or dateutil. + + If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone + string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and + it reads system configuration to know the system local timezone. + + See also: + - https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753 + - https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338 + """ + import os + return os.environ.get('TZ', 'dateutil/:') + + +def _check_series_localize_timestamps(s, timezone): + """ + Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone. + + If the input series is not a timestamp series, then the same series is returned. If the input + series is a timestamp series, then a converted series is returned. + + :param s: pandas.Series + :param timezone: the timezone to convert. if None then use local timezone + :return pandas.Series that have been converted to tz-naive + """ + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64tz_dtype + tz = timezone or _get_local_timezone() + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert(tz).dt.tz_localize(None) + else: + return s + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1693,12 +1739,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone): from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() - from pandas.api.types import is_datetime64tz_dtype - tz = timezone or 'tzlocal()' for column, series in pdf.iteritems(): - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if is_datetime64tz_dtype(series.dtype): - pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None) + pdf[column] = _check_series_localize_timestamps(series, timezone) return pdf @@ -1717,8 +1759,38 @@ def _check_series_convert_timestamps_internal(s, timezone): from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64_dtype(s.dtype): - tz = timezone or 'tzlocal()' - return s.dt.tz_localize(tz).dt.tz_convert('UTC') + # When tz_localize a tz-naive timestamp, the result is ambiguous if the tz-naive + # timestamp is during the hour when the clock is adjusted backward during due to + # daylight saving time (dst). 
+ # E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to + # 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize + # a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either + # dst time (2015-01-01 1:30-0400) or standard time (2015-11-01 1:30-0500). + # + # Here we explicit choose to use standard time. This matches the default behavior of + # pytz. + # + # Here are some code to help understand this behavior: + # >>> import datetime + # >>> import pandas as pd + # >>> import pytz + # >>> + # >>> t = datetime.datetime(2015, 11, 1, 1, 30) + # >>> ts = pd.Series([t]) + # >>> tz = pytz.timezone('America/New_York') + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=True) + # 0 2015-11-01 01:30:00-04:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=False) + # 0 2015-11-01 01:30:00-05:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> str(tz.localize(t)) + # '2015-11-01 01:30:00-05:00' + tz = timezone or _get_local_timezone() + return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC') elif is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert('UTC') else: @@ -1739,15 +1811,16 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): import pandas as pd from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype - from_tz = from_timezone or 'tzlocal()' - to_tz = to_timezone or 'tzlocal()' + from_tz = from_timezone or _get_local_timezone() + to_tz = to_timezone or _get_local_timezone() # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert(to_tz).dt.tz_localize(None) elif is_datetime64_dtype(s.dtype) and from_tz != to_tz: # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT. 
- return s.apply(lambda ts: ts.tz_localize(from_tz).tz_convert(to_tz).tz_localize(None) - if ts is not pd.NaT else pd.NaT) + return s.apply( + lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None) + if ts is not pd.NaT else pd.NaT) else: return s @@ -1785,7 +1858,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 54b5a8656e1c..c8fb49d7c2b6 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -18,11 +18,16 @@ User-defined function related classes and functions """ import functools +import sys -from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark import SparkContext, since +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ + to_arrow_type, to_arrow_schema +from pyspark.util import _get_argspec, fail_on_stopiteration + +__all__ = ["UDFRegistration"] def _wrap_function(sc, func, returnType): @@ -34,29 +39,31 @@ def _wrap_function(sc, func, returnType): def _create_udf(f, returnType, evalType): - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \ - evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: - import inspect - from pyspark.sql.utils import require_minimum_pyarrow_version + if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): + from pyspark.sql.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() - argspec = inspect.getargspec(f) - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \ + argspec = _get_argspec(f) + + if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: raise ValueError( "Invalid function: 0-arg pandas_udfs are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - if evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF and len(argspec.args) != 1: + if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ + and len(argspec.args) not in (1, 2): raise ValueError( - "Invalid function: pandas_udfs with function type GROUP_MAP " - "must take a single arg that is a pandas DataFrame." - ) + "Invalid function: pandas_udfs with function type GROUPED_MAP " + "must take either one argument (data) or two arguments (key, data).") # Set the name of the UserDefinedFunction object to be the name of function f - udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + udf_obj = UserDefinedFunction( + f, returnType=returnType, name=None, evalType=evalType, deterministic=True) return udf_obj._wrapped() @@ -67,8 +74,10 @@ class UserDefinedFunction(object): .. 
versionadded:: 1.3 """ def __init__(self, func, - returnType=StringType(), name=None, - evalType=PythonEvalType.SQL_BATCHED_UDF): + returnType=StringType(), + name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=True): if not callable(func): raise TypeError( "Invalid function: not a function or callable (__call__ is not defined): " @@ -92,7 +101,7 @@ def __init__(self, func, func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) self.evalType = evalType - self._deterministic = True + self.deterministic = deterministic @property def returnType(self): @@ -104,10 +113,31 @@ def returnType(self): else: self._returnType_placeholder = _parse_datatype_string(self._returnType) - if self.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \ - and not isinstance(self._returnType_placeholder, StructType): - raise ValueError("Invalid returnType: returnType must be a StructType for " - "pandas_udf with function type GROUP_MAP") + if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF: + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with scalar Pandas UDFs: %s is " + "not supported" % str(self._returnType_placeholder)) + elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + if isinstance(self._returnType_placeholder, StructType): + try: + to_arrow_schema(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with grouped map Pandas UDFs: " + "%s is not supported" % str(self._returnType_placeholder)) + else: + raise TypeError("Invalid returnType for grouped map Pandas " + "UDFs: returnType must be a StructType.") + elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with grouped aggregate Pandas UDFs: " + "%s is not supported" % str(self._returnType_placeholder)) return self._returnType_placeholder @@ -127,10 +157,20 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, self.func, self.returnType) + func = fail_on_stopiteration(self.func) + + # for pandas UDFs the worker needs to know if the function takes + # one or two arguments, but the signature is lost when wrapping with + # fail_on_stopiteration, so we store it here + if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): + func._argspec = _get_argspec(self.func) + + wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.evalType, self._deterministic) + self._name, wrapped_func, jdt, self.evalType, self.deterministic) return judf def __call__(self, *cols): @@ -138,6 +178,9 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + # This function is for improving the online help system in the interactive interpreter. + # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and + # argument annotation. 
(See: SPARK-19161) def _wrapped(self): """ Wrap this udf with a function and attach docstring from func @@ -162,8 +205,9 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType wrapper.evalType = self.evalType - wrapper.asNondeterministic = self.asNondeterministic - + wrapper.deterministic = self.deterministic + wrapper.asNondeterministic = functools.wraps( + self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped()) return wrapper def asNondeterministic(self): @@ -172,5 +216,198 @@ def asNondeterministic(self): .. versionadded:: 2.3 """ - self._deterministic = False + # Here, we explicitly clean the cache to create a JVM UDF instance + # with 'deterministic' updated. See SPARK-23233. + self._judf_placeholder = None + self.deterministic = False return self + + +class UDFRegistration(object): + """ + Wrapper for user-defined function registration. This instance can be accessed by + :attr:`spark.udf` or :attr:`sqlContext.udf`. + + .. versionadded:: 1.3.1 + """ + + def __init__(self, sparkSession): + self.sparkSession = sparkSession + + @ignore_unicode_prefix + @since("1.3.1") + def register(self, name, f, returnType=None): + """Register a Python function (including lambda function) or a user-defined function + as a SQL function. + + :param name: name of the user-defined function in SQL statements. + :param f: a Python function, or a user-defined function. The user-defined function can + be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and + :meth:`pyspark.sql.functions.pandas_udf`. + :param returnType: the return type of the registered user-defined function. The value can + be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + :return: a user-defined function. + + To register a nondeterministic Python function, users need to first build + a nondeterministic user-defined function for the Python function and then register it + as a SQL function. + + `returnType` can be optionally specified when `f` is a Python function but not + when `f` is a user-defined function. Please see below. + + 1. When `f` is a Python function: + + `returnType` defaults to string type and can be optionally specified. The produced + object must match the specified type. In this case, this API works as if + `register(name, f, returnType=StringType())`. + + >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x)) + >>> spark.sql("SELECT stringLengthString('test')").collect() + [Row(stringLengthString(test)=u'4')] + + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + 2. When `f` is a user-defined function: + + Spark uses the return type of the given user-defined function as the return type of + the registered user-defined function. `returnType` should not be specified. + In this case, this API works as if `register(name, f)`. 
+ + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> new_random_udf = spark.udf.register("random_udf", random_udf) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=82)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + + .. note:: Registration for a user-defined function (case 2.) was added from + Spark 2.3.0. + """ + + # This is to check whether the input function is from a user-defined function or + # Python function. + if hasattr(f, 'asNondeterministic'): + if returnType is not None: + raise TypeError( + "Invalid returnType: data type can not be specified when f is" + "a user-defined function, but got %s." % returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_SCALAR_PANDAS_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f + else: + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf + + @ignore_unicode_prefix + @since(2.3) + def registerJavaFunction(self, name, javaClassName, returnType=None): + """Register a Java user-defined function as a SQL function. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not specified we would infer it via reflection. + + :param name: name of the user-defined function + :param javaClassName: fully qualified name of java class + :param returnType: the return type of the registered Java function. The value can be either + a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + + >>> from pyspark.sql.types import IntegerType + >>> spark.udf.registerJavaFunction( + ... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> spark.sql("SELECT javaStringLength('test')").collect() + [Row(UDF:javaStringLength(test)=4)] + + >>> spark.udf.registerJavaFunction( + ... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") + >>> spark.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF:javaStringLength2(test)=4)] + + >>> spark.udf.registerJavaFunction( + ... 
"javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer") + >>> spark.sql("SELECT javaStringLength3('test')").collect() + [Row(UDF:javaStringLength3(test)=4)] + """ + + jdt = None + if returnType is not None: + if not isinstance(returnType, DataType): + returnType = _parse_datatype_string(returnType) + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + + @ignore_unicode_prefix + @since(2.3) + def registerJavaUDAF(self, name, javaClassName): + """Register a Java user-defined aggregate function as a SQL function. + + :param name: name of the user-defined aggregate function + :param javaClassName: fully qualified name of java class + + >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.createOrReplaceTempView("df") + >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + + +def _test(): + import doctest + from pyspark.sql import SparkSession + import pyspark.sql.udf + globs = pyspark.sql.udf.__dict__.copy() + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.udf tests")\ + .getOrCreate() + globs['spark'] = spark + (failure_count, test_count) = doctest.testmod( + pyspark.sql.udf, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 08c34c6dccc5..45363f089a73 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -115,18 +115,38 @@ def toJArray(gateway, jtype, arr): def require_minimum_pandas_version(): """ Raise ImportError if minimum version of Pandas is not installed """ + # TODO(HyukjinKwon): Relocate and deduplicate the version specification. + minimum_pandas_version = "0.19.2" + from distutils.version import LooseVersion - import pandas - if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'): - raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; " - "however, your version was %s." % pandas.__version__) + try: + import pandas + have_pandas = True + except ImportError: + have_pandas = False + if not have_pandas: + raise ImportError("Pandas >= %s must be installed; however, " + "it was not found." % minimum_pandas_version) + if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version): + raise ImportError("Pandas >= %s must be installed; however, " + "your version was %s." % (minimum_pandas_version, pandas.__version__)) def require_minimum_pyarrow_version(): """ Raise ImportError if minimum version of pyarrow is not installed """ + # TODO(HyukjinKwon): Relocate and deduplicate the version specification. + minimum_pyarrow_version = "0.8.0" + from distutils.version import LooseVersion - import pyarrow - if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'): - raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; " - "however, your version was %s." % pyarrow.__version__) + try: + import pyarrow + have_arrow = True + except ImportError: + have_arrow = False + if not have_arrow: + raise ImportError("PyArrow >= %s must be installed; however, " + "it was not found." 
% minimum_pyarrow_version) + if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): + raise ImportError("PyArrow >= %s must be installed; however, " + "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 7ce27f9b102c..d19ced954f04 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -16,9 +16,11 @@ # import sys +if sys.version >= '3': + long = int from pyspark import since, SparkContext -from pyspark.sql.column import _to_seq, _to_java_column +from pyspark.sql.column import Column, _to_seq, _to_java_column __all__ = ["Window", "WindowSpec"] @@ -42,6 +44,10 @@ class Window(object): >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) + .. note:: When ordering is not defined, an unbounded window frame (rowFrame, + unboundedPreceding, unboundedFollowing) is used by default. When ordering is defined, + a growing window frame (rangeFrame, unboundedPreceding, currentRow) is used by default. + .. note:: Experimental .. versionadded:: 1.4 @@ -120,20 +126,45 @@ def rangeBetween(start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - and ``Window.currentRow`` to specify special boundary values, rather than using integral - values directly. + ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, + ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` + to specify special boundary values, rather than using integral values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, + a column returned by ``pyspark.sql.functions.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, + a column returned by ``pyspark.sql.functions.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). + + >>> from pyspark.sql import functions as F, SparkSession, Window + >>> spark = SparkSession.builder.getOrCreate() + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"]) + >>> window = Window.orderBy("id").partitionBy("category").rangeBetween( + ... 
F.currentRow(), F.lit(1)) + >>> df.withColumn("sum", F.sum("id").over(window)).show() + +---+--------+---+ + | id|category|sum| + +---+--------+---+ + | 1| b| 3| + | 2| b| 5| + | 3| b| 3| + | 1| a| 4| + | 1| a| 4| + | 2| a| 2| + +---+--------+---+ """ - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing + if isinstance(start, (int, long)) and isinstance(end, (int, long)): + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing + elif isinstance(start, Column) and isinstance(end, Column): + start = start._jc + end = end._jc sc = SparkContext._active_spark_context jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end) return WindowSpec(jspec) @@ -208,29 +239,36 @@ def rangeBetween(self, start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - and ``Window.currentRow`` to specify special boundary values, rather than using integral - values directly. + ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, + ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` + to specify special boundary values, rather than using integral values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, + a column returned by ``pyspark.sql.functions.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, + a column returned by ``pyspark.sql.functions.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). """ - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing + if isinstance(start, (int, long)) and isinstance(end, (int, long)): + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing + elif isinstance(start, Column) and isinstance(end, Column): + start = start._jc + end = end._jc return WindowSpec(self._jspec.rangeBetween(start, end)) def _test(): import doctest SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod() + (failure_count, test_count) = doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index fdb930860448..ed2e0e7d10fa 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -104,7 +104,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, :param topics: list of topic_name to consume. :param kafkaParams: Additional params for Kafka. :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting - point of the stream. + point of the stream (a dictionary mapping `TopicAndPartition` to + integers). :param keyDecoder: A function used to decode key (default is utf8_decoder). 
:param valueDecoder: A function used to decode value (default is utf8_decoder). :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess diff --git a/python/pyspark/streaming/listener.py b/python/pyspark/streaming/listener.py index b830797f5c0a..d4ecc215aea9 100644 --- a/python/pyspark/streaming/listener.py +++ b/python/pyspark/streaming/listener.py @@ -23,6 +23,12 @@ class StreamingListener(object): def __init__(self): pass + def onStreamingStarted(self, streamingStarted): + """ + Called when the streaming has been started. + """ + pass + def onReceiverStarted(self, receiverStarted): """ Called when a receiver has been started diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5b86c1cb2c39..e4a428a0b27e 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -63,7 +63,7 @@ def setUpClass(cls): class_name = cls.__name__ conf = SparkConf().set("spark.default.parallelism", 1) cls.sc = SparkContext(appName=class_name, conf=conf) - cls.sc.setCheckpointDir("/tmp") + cls.sc.setCheckpointDir(tempfile.mkdtemp()) @classmethod def tearDownClass(cls): @@ -507,6 +507,10 @@ def __init__(self): self.batchInfosCompleted = [] self.batchInfosStarted = [] self.batchInfosSubmitted = [] + self.streamingStartedTime = [] + + def onStreamingStarted(self, streamingStarted): + self.streamingStartedTime.append(streamingStarted.time) def onBatchSubmitted(self, batchSubmitted): self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) @@ -530,9 +534,12 @@ def func(dstream): batchInfosSubmitted = batch_collector.batchInfosSubmitted batchInfosStarted = batch_collector.batchInfosStarted batchInfosCompleted = batch_collector.batchInfosCompleted + streamingStartedTime = batch_collector.streamingStartedTime self.wait_for(batchInfosCompleted, 4) + self.assertEqual(len(streamingStartedTime), 1) + self.assertGreaterEqual(len(batchInfosSubmitted), 4) for info in batchInfosSubmitted: self.assertGreaterEqual(info.batchTime().milliseconds(), 0) @@ -1477,8 +1484,8 @@ def search_kafka_assembly_jar(): raise Exception( ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn -Pkafka-0-8 package' before running this test.") + "'build/sbt -Pkafka-0-8 assembly/package streaming-kafka-0-8-assembly/assembly' or " + "'build/mvn -DskipTests -Pkafka-0-8 package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1494,8 +1501,8 @@ def search_flume_assembly_jar(): raise Exception( ("Failed to find Spark Streaming Flume assembly jar in %s. 
" % flume_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn -Pflume package' before running this test.") + "'build/sbt -Pflume assembly/package streaming-flume-assembly/assembly' or " + "'build/mvn -DskipTests -Pflume package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1503,10 +1510,13 @@ def search_flume_assembly_jar(): return jars[0] -def search_kinesis_asl_assembly_jar(): +def _kinesis_asl_assembly_dir(): SPARK_HOME = os.environ["SPARK_HOME"] - kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") - jars = search_jar(kinesis_asl_assembly_dir, "spark-streaming-kinesis-asl-assembly") + return os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") + + +def search_kinesis_asl_assembly_jar(): + jars = search_jar(_kinesis_asl_assembly_dir(), "spark-streaming-kinesis-asl-assembly") if not jars: return None elif len(jars) > 1: @@ -1539,7 +1549,9 @@ def search_kinesis_asl_assembly_jar(): kinesis_jar_present = True jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) - os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars + existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + jars_args = "--jars %s" % jars + os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, StreamingListenerTests] @@ -1569,7 +1581,7 @@ def search_kinesis_asl_assembly_jar(): else: raise Exception( ("Failed to find Spark Streaming Kinesis assembly jar in %s. " - % kinesis_asl_assembly_dir) + + % _kinesis_asl_assembly_dir()) + "You need to build Spark with 'build/sbt -Pkinesis-asl " "assembly/package streaming-kinesis-asl-assembly/assembly'" "or 'build/mvn -Pkinesis-asl package' before running this test.") @@ -1580,11 +1592,11 @@ def search_kinesis_asl_assembly_jar(): sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) if xmlrunner: - result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=3).run(tests) + result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2).run(tests) if not result.wasSuccessful(): failed = True else: - result = unittest.TextTestRunner(verbosity=3).run(tests) + result = unittest.TextTestRunner(verbosity=2).run(tests) if not result.wasSuccessful(): failed = True sys.exit(failed) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index abbbf6eb9394..df184471993f 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -18,6 +18,7 @@ import time from datetime import datetime import traceback +import sys from pyspark import SparkContext, RDD @@ -147,4 +148,4 @@ def rddToFileName(prefix, suffix, timestamp): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index da99872da2f0..3b37cc028c1b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -161,6 +161,37 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) + def test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + 
def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -1239,6 +1270,35 @@ def test_pipe_functions(self): self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) self.assertEqual([], rdd.pipe('grep 4').collect()) + def test_pipe_unicode(self): + # Regression test for SPARK-20947 + data = [u'\u6d4b\u8bd5', '1'] + rdd = self.sc.parallelize(data) + result = rdd.pipe('cat').collect() + self.assertEqual(data, result) + + def test_stopiteration_in_client_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + + self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) + self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) + + # the exception raised is non-deterministic + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + class ProfilerTests(PySparkTestCase): @@ -1944,7 +2004,12 @@ class SparkSubmitTests(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() - self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") + tmp_dir = tempfile.gettempdir() + self.sparkSubmit = [ + os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + ] def tearDown(self): shutil.rmtree(self.programDir) @@ -2010,7 +2075,7 @@ def test_single_script(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 4, 6]", out.decode('utf-8')) @@ -2026,7 +2091,7 @@ def test_script_with_local_functions(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, 
proc.returncode) self.assertIn("[3, 6, 9]", out.decode('utf-8')) @@ -2044,7 +2109,7 @@ def test_module_dependency(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script], + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2063,7 +2128,7 @@ def test_module_dependency_on_cluster(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() @@ -2080,8 +2145,10 @@ def test_package_dependency(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2096,9 +2163,11 @@ def test_package_dependency_on_cluster(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, "--master", - "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", + script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2117,7 +2186,7 @@ def test_single_script_on_cluster(self): # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2137,7 +2206,7 @@ def test_user_configuration(self): | sc.stop() """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local", script], + self.sparkSubmit + ["--master", "local", script], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) out, err = proc.communicate() @@ -2286,6 +2355,21 @@ def set(self, x=None, other=None, other_x=None): self.assertEqual(b._x, 2) +class UtilTests(PySparkTestCase): + def test_py4j_exception_message(self): + from pyspark.util import _exception_message + + with self.assertRaises(Py4JJavaError) as context: + # This attempts java.lang.String(null) which throws an NPE. 
+ self.sc._jvm.java.lang.String(None) + + self.assertTrue('NullPointerException' in _exception_message(context.exception)) + + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): @@ -2335,15 +2419,7 @@ def test_statcounter_array(self): if __name__ == "__main__": from pyspark.tests import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") - if not _have_numpy: - print("NOTE: Skipping NumPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") - if not _have_numpy: - print("NOTE: NumPy tests were skipped as it does not seem to be installed") + unittest.main(verbosity=2) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index e5d332ce5442..e95a9b523393 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -16,6 +16,11 @@ # limitations under the License. # +import re +import sys +import inspect +from py4j.protocol import Py4JJavaError + __all__ = [] @@ -33,13 +38,81 @@ def _exception_message(excp): >>> msg == _exception_message(excp) True """ + if isinstance(excp, Py4JJavaError): + # 'Py4JJavaError' doesn't contain the stack trace available on the Java side in 'message' + # attribute in Python 2. We should call 'str' function on this exception in general but + # 'Py4JJavaError' has an issue about addressing non-ascii strings. So, here we work + # around by the direct call, '__str__()'. Please see SPARK-23517. + return excp.__str__() if hasattr(excp, "message"): return excp.message return str(excp) +def _get_argspec(f): + """ + Get argspec of a function. Supports both Python 2 and Python 3. + """ + + if hasattr(f, '_argspec'): + # only used for pandas UDF: they wrap the user function, losing its signature + # workers need this signature, so UDF saves it here + argspec = f._argspec + elif sys.version_info[0] < 3: + argspec = inspect.getargspec(f) + else: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. + argspec = inspect.getfullargspec(f) + return argspec + + +class VersionUtils(object): + """ + Provides utility method to determine Spark versions with given input string. + """ + @staticmethod + def majorMinorVersion(sparkVersion): + """ + Given a Spark version string, return the (major version number, minor version number). + E.g., for 2.0.1-SNAPSHOT, return (2, 0). 
+ + >>> sparkVersion = "2.4.0" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 4) + >>> sparkVersion = "2.3.0-SNAPSHOT" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 3) + + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion) + if m is not None: + return (int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Spark tried to parse '%s' as a Spark" % sparkVersion + + " version string, but it could not find the major and minor" + + " version numbers.") + + +def fail_on_stopiteration(f): + """ + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError( + "Caught StopIteration thrown from user's code; failing the task", + exc + ) + + return wrapper + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 12dd53b9d290..b9c2c4ced71d 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.3.0.dev0" +__version__ = "2.4.0.dev0" diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e6737ae1c128..a1a4336b1e8d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -34,6 +34,7 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type +from pyspark.util import _get_argspec from pyspark import shuffle pickleSer = PickleSerializer() @@ -74,7 +75,7 @@ def wrap_udf(f, return_type): return lambda *a: f(*a) -def wrap_pandas_scalar_udf(f, return_type): +def wrap_scalar_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def verify_result_length(*a): @@ -90,11 +91,17 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_pandas_group_map_udf(f, return_type): - def wrapped(*series): +def wrap_grouped_map_pandas_udf(f, return_type): + def wrapped(key_series, value_series): import pandas as pd + argspec = _get_argspec(f) + + if len(argspec.args) == 1: + result = f(pd.concat(value_series, axis=1)) + elif len(argspec.args) == 2: + key = tuple(s[0] for s in key_series) + result = f(key, pd.concat(value_series, axis=1)) - result = f(pd.concat(series, axis=1)) if not isinstance(result, pd.DataFrame): raise TypeError("Return type of the user-defined function should be " "pandas.DataFrame, but is {}".format(type(result))) @@ -110,6 +117,17 @@ def wrapped(*series): return wrapped +def wrap_grouped_agg_pandas_udf(f, return_type): + arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series([result]) + + return lambda *a: (wrapped(*a), arrow_return_type) + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -122,34 +140,57 @@ def read_single_udf(pickleSer, infile, eval_type): row_func = chain(row_func, f) # the last returnType will be the return type of UDF - if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF: - return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) - elif eval_type == 
PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: - return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) - else: + if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: + return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: + return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(row_func, return_type) + else: + raise ValueError("Unknown eval type: {}".format(eval_type)) def read_udfs(pickleSer, infile, eval_type): num_udfs = read_int(infile) udfs = {} call_udf = [] - for i in range(num_udfs): + mapper_str = "" + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + # Create function like this: + # lambda a: f([a[0]], [a[0], a[1]]) + + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandasExec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) - udfs['f%d' % i] = udf - args = ["a[%d]" % o for o in arg_offsets] - call_udf.append("f%d(%s)" % (i, ", ".join(args))) - # Create function like this: - # lambda a: (f0(a0), f1(a1, a2), f2(a3)) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. - mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) - mapper = eval(mapper_str, udfs) + udfs['f'] = udf + split_offset = arg_offsets[0] + 1 + arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] + arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]] + mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1)) + else: + # Create function like this: + # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3])) + # In the special case of a single UDF this will return a single result rather + # than a tuple of results; this is the format that the JVM side expects. 
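            # Illustrative sketch (editorial comment, not part of the patch): with two
            # UDFs whose arg_offsets are [0] and [1, 2], the loop below produces
            #   udfs     == {'f0': <wrapped udf 0>, 'f1': <wrapped udf 1>}
            #   call_udf == ['f0(a[0])', 'f1(a[1], a[2])']
            # and therefore
            #   mapper_str == "lambda a: (f0(a[0]), f1(a[1], a[2]))"
            # which eval() turns into the row-wise mapper applied to each element of
            # the input iterator.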
+ for i in range(num_udfs): + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + udfs['f%d' % i] = udf + args = ["a[%d]" % o for o in arg_offsets] + call_udf.append("f%d(%s)" % (i, ", ".join(args))) + mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) + mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: @@ -164,7 +205,7 @@ def main(infile, outfile): boot_time = time.time() split_index = read_int(infile) if split_index == -1: # for unit tests - exit(-1) + sys.exit(-1) version = utf8_deserializer.loads(infile) if version != "%d.%d" % sys.version_info[:2]: @@ -238,7 +279,7 @@ def process(): # Write the error to stderr if it happened while serializing print("PySpark worker failed with exception:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - exit(-1) + sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) write_long(shuffle.MemoryBytesSpilled, outfile) @@ -256,7 +297,7 @@ def process(): else: # write a different value to tell JVM to not reuse this worker write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) - exit(-1) + sys.exit(-1) if __name__ == '__main__': diff --git a/python/run-tests-with-coverage b/python/run-tests-with-coverage new file mode 100755 index 000000000000..6d74b563e914 --- /dev/null +++ b/python/run-tests-with-coverage @@ -0,0 +1,69 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -o pipefail +set -e + +# This variable indicates which coverage executable to run to combine coverages +# and generate HTMLs, for example, 'coverage3' in Python 3. +COV_EXEC="${COV_EXEC:-coverage}" +FWDIR="$(cd "`dirname $0`"; pwd)" +pushd "$FWDIR" > /dev/null + +# Ensure that coverage executable is installed. +if ! hash $COV_EXEC 2>/dev/null; then + echo "Missing coverage executable in your path, skipping PySpark coverage" + exit 1 +fi + +# Set up the directories for coverage results. +export COVERAGE_DIR="$FWDIR/test_coverage" +rm -fr "$COVERAGE_DIR/coverage_data" +rm -fr "$COVERAGE_DIR/htmlcov" +mkdir -p "$COVERAGE_DIR/coverage_data" + +# Current directory are added in the python path so that it doesn't refer our built +# pyspark zip library first. +export PYTHONPATH="$FWDIR:$PYTHONPATH" +# Also, our sitecustomize.py and coverage_daemon.py are included in the path. 
+export PYTHONPATH="$COVERAGE_DIR:$PYTHONPATH" + +# We use 'spark.python.daemon.module' configuration to insert the coverage supported workers. +export SPARK_CONF_DIR="$COVERAGE_DIR/conf" + +# This environment variable enables the coverage. +export COVERAGE_PROCESS_START="$FWDIR/.coveragerc" + +# If you'd like to run a specific unittest class, you could do such as +# SPARK_TESTING=1 ../bin/pyspark pyspark.sql.tests VectorizedUDFTests +./run-tests "$@" + +# Don't run coverage for the coverage command itself +unset COVERAGE_PROCESS_START + +# Coverage could generate empty coverage data files. Remove it to get rid of warnings when combining. +find $COVERAGE_DIR/coverage_data -size 0 -print0 | xargs -0 rm +echo "Combining collected coverage data under $COVERAGE_DIR/coverage_data" +$COV_EXEC combine +echo "Reporting the coverage data at $COVERAGE_DIR/coverage_data/coverage" +$COV_EXEC report --include "pyspark/*" +echo "Generating HTML files for PySpark coverage under $COVERAGE_DIR/htmlcov" +$COV_EXEC html --ignore-errors --include "pyspark/*" --directory "$COVERAGE_DIR/htmlcov" + +popd diff --git a/python/run-tests.py b/python/run-tests.py index 1341086f02db..4c90926cfa35 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -22,15 +22,19 @@ from optparse import OptionParser import os import re +import shutil import subprocess import sys import tempfile from threading import Thread, Lock import time +import uuid if sys.version < '3': import Queue else: import queue as Queue +from distutils.version import LooseVersion +from multiprocessing import Manager # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -39,7 +43,7 @@ from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) from sparktestsupport.shellutils import which, subprocess_check_output # noqa -from sparktestsupport.modules import all_modules # noqa +from sparktestsupport.modules import all_modules, pyspark_sql # noqa python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root') @@ -49,6 +53,7 @@ def print_red(text): print('\033[31m' + text + '\033[0m') +SKIPPED_TESTS = Manager().dict() LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") FAILURE_REPORTING_LOCK = Lock() LOGGER = logging.getLogger() @@ -65,7 +70,7 @@ def print_red(text): raise Exception("Cannot find assembly build directory, please build Spark first.") -def run_individual_python_test(test_name, pyspark_python): +def run_individual_python_test(target_dir, test_name, pyspark_python): env = dict(os.environ) env.update({ 'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH, @@ -74,6 +79,23 @@ def run_individual_python_test(test_name, pyspark_python): 'PYSPARK_PYTHON': which(pyspark_python), 'PYSPARK_DRIVER_PYTHON': which(pyspark_python) }) + + # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is + # recognized by the tempfile module to override the default system temp directory. + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + while os.path.isdir(tmp_dir): + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + os.mkdir(tmp_dir) + env["TMPDIR"] = tmp_dir + + # Also override the JVM's temp directory by setting driver and executor options. 
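    # Illustrative (editorial comment, not part of the patch): with a generated
    # tmp_dir such as .../python/target/<uuid>, the PYSPARK_SUBMIT_ARGS built below
    # is the single space-joined string
    #   --conf spark.driver.extraJavaOptions=-Djava.io.tmpdir=.../python/target/<uuid>
    #   --conf spark.executor.extraJavaOptions=-Djava.io.tmpdir=.../python/target/<uuid>
    #   pyspark-shell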
+ spark_args = [ + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "pyspark-shell" + ] + env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args) + LOGGER.info("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: @@ -81,6 +103,7 @@ def run_individual_python_test(test_name, pyspark_python): retcode = subprocess.Popen( [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], stderr=per_test_output, stdout=per_test_output, env=env).wait() + shutil.rmtree(tmp_dir, ignore_errors=True) except: LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if @@ -108,8 +131,34 @@ def run_individual_python_test(test_name, pyspark_python): # this code is invoked from a thread other than the main thread. os._exit(-1) else: - per_test_output.close() - LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) + skipped_counts = 0 + try: + per_test_output.seek(0) + # Here expects skipped test output from unittest when verbosity level is + # 2 (or --verbose option is enabled). + decoded_lines = map(lambda line: line.decode(), iter(per_test_output)) + skipped_tests = list(filter( + lambda line: re.search('test_.* \(pyspark\..*\) ... skipped ', line), + decoded_lines)) + skipped_counts = len(skipped_tests) + if skipped_counts > 0: + key = (pyspark_python, test_name) + SKIPPED_TESTS[key] = skipped_tests + per_test_output.close() + except: + import traceback + print_red("\nGot an exception while trying to store " + "skipped test output:\n%s" % traceback.format_exc()) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) + if skipped_counts != 0: + LOGGER.info( + "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name, + duration, skipped_counts) + else: + LOGGER.info( + "Finished test(%s): %s (%is)", pyspark_python, test_name, duration) def get_default_python_executables(): @@ -151,6 +200,19 @@ def parse_opts(): return opts +def _check_coverage(python_exec): + # Make sure if coverage is installed. + try: + subprocess_check_output( + [python_exec, "-c", "import coverage"], + stderr=open(os.devnull, 'w')) + except: + print_red("Coverage is not installed in Python executable '%s' " + "but 'COVERAGE_PROCESS_START' environment variable is set, " + "exiting." % python_exec) + sys.exit(-1) + + def main(): opts = parse_opts() if (opts.verbose): @@ -175,6 +237,11 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: + # Check if the python executable has coverage installed when 'COVERAGE_PROCESS_START' + # environmental variable is set. + if "COVERAGE_PROCESS_START" in os.environ: + _check_coverage(python_exec) + python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() @@ -191,6 +258,11 @@ def main(): priority = 100 task_queue.put((priority, (python_exec, test_goal))) + # Create the target directory before starting tasks to avoid races. 
+ target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target')) + if not os.path.isdir(target_dir): + os.mkdir(target_dir) + def process_queue(task_queue): while True: try: @@ -198,7 +270,7 @@ def process_queue(task_queue): except Queue.Empty: break try: - run_individual_python_test(test_goal, python_exec) + run_individual_python_test(target_dir, test_goal, python_exec) finally: task_queue.task_done() @@ -215,6 +287,12 @@ def process_queue(task_queue): total_duration = time.time() - start_time LOGGER.info("Tests passed in %i seconds", total_duration) + for key, lines in sorted(SKIPPED_TESTS.items()): + pyspark_python, test_name = key + LOGGER.info("\nSkipped tests in %s with %s:" % (test_name, pyspark_python)) + for line in lines: + LOGGER.info(" %s" % line.rstrip()) + if __name__ == "__main__": main() diff --git a/python/setup.py b/python/setup.py index 251d4526d4dd..794ceceae300 100644 --- a/python/setup.py +++ b/python/setup.py @@ -26,7 +26,7 @@ if sys.version_info < (2, 7): print("Python versions prior to 2.7 are not supported for pip installed PySpark.", file=sys.stderr) - exit(-1) + sys.exit(-1) try: exec(open('pyspark/version.py').read()) @@ -98,7 +98,12 @@ def _supports_symlinks(): except: print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH), file=sys.stderr) - exit(-1) + sys.exit(-1) + +# If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and +# ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml. +_minimum_pandas_version = "0.19.2" +_minimum_pyarrow_version = "0.8.0" try: # We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts @@ -135,7 +140,7 @@ def _supports_symlinks(): if not os.path.isdir(SCRIPTS_TARGET): print(incorrect_invocation_message, file=sys.stderr) - exit(-1) + sys.exit(-1) # Scripts directive requires a list of each script path and does not take wild cards. script_names = os.listdir(SCRIPTS_TARGET) @@ -201,7 +206,10 @@ def _supports_symlinks(): extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0'] + 'sql': [ + 'pandas>=%s' % _minimum_pandas_version, + 'pyarrow>=%s' % _minimum_pyarrow_version, + ] }, classifiers=[ 'Development Status :: 5 - Production/Stable', diff --git a/python/test_coverage/conf/spark-defaults.conf b/python/test_coverage/conf/spark-defaults.conf new file mode 100644 index 000000000000..bf44ea6e7cfe --- /dev/null +++ b/python/test_coverage/conf/spark-defaults.conf @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This is used to generate PySpark coverage results. 
Seems there's no way to +# add a configuration when SPARK_TESTING environment variable is set because +# we will directly execute modules by python -m. +spark.python.daemon.module coverage_daemon diff --git a/python/test_coverage/coverage_daemon.py b/python/test_coverage/coverage_daemon.py new file mode 100644 index 000000000000..c87366a1ac23 --- /dev/null +++ b/python/test_coverage/coverage_daemon.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import imp + + +# This is a hack to always refer the main code rather than built zip. +main_code_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +daemon = imp.load_source("daemon", "%s/pyspark/daemon.py" % main_code_dir) + +if "COVERAGE_PROCESS_START" in os.environ: + worker = imp.load_source("worker", "%s/pyspark/worker.py" % main_code_dir) + + def _cov_wrapped(*args, **kwargs): + import coverage + cov = coverage.coverage( + config_file=os.environ["COVERAGE_PROCESS_START"]) + cov.start() + try: + worker.main(*args, **kwargs) + finally: + cov.stop() + cov.save() + daemon.worker_main = _cov_wrapped +else: + raise RuntimeError("COVERAGE_PROCESS_START environment variable is not set, exiting.") + + +if __name__ == '__main__': + daemon.manager() diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/python/test_coverage/sitecustomize.py similarity index 64% rename from resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile rename to python/test_coverage/sitecustomize.py index 055493188fcb..630237a51812 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/python/test_coverage/sitecustomize.py @@ -15,10 +15,9 @@ # limitations under the License. # -FROM spark-base - -# If this docker file is being used in the context of building your images from a Spark distribution, the docker build -# command should be invoked from the top level directory of the Spark distribution. E.g.: -# docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . - -ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.rest.k8s.SparkPodInitContainer" ] +# Note that this 'sitecustomize' module is a built-in feature in Python. +# If this module is defined, it's executed when the Python session begins. +# `coverage.process_startup()` seeks if COVERAGE_PROCESS_START environment +# variable is set or not. If set, it starts to run the coverage. 
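# Editorial comment (not part of the patch): in this change, run-tests-with-coverage
# exports COVERAGE_PROCESS_START="$FWDIR/.coveragerc", so any Python process that
# imports this sitecustomize module begins recording coverage; when the variable is
# not set, coverage.process_startup() is a no-op.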
+import coverage +coverage.process_startup() diff --git a/python/test_support/sql/people_array_utf16le.json b/python/test_support/sql/people_array_utf16le.json new file mode 100644 index 000000000000..9c657fa30ac9 Binary files /dev/null and b/python/test_support/sql/people_array_utf16le.json differ diff --git a/repl/pom.xml b/repl/pom.xml index 1cb0098d0eca..6f4a863c48bc 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 127f67329f26..4dc399827ffe 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,12 +17,10 @@ package org.apache.spark.repl -import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream, IOException} -import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream} +import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels -import scala.util.control.NonFatal - import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.xbean.asm5._ import org.apache.xbean.asm5.Opcodes._ @@ -30,13 +28,13 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.util.{ParentClassLoader, Utils} +import org.apache.spark.util.ParentClassLoader /** - * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, used to load classes - * defined by the interpreter when the REPL is used. Allows the user to specify if user class path - * should be first. This class loader delegates getting/finding resources to parent loader, which - * makes sense until REPL never provide resource dynamically. + * A ClassLoader that reads classes from a Hadoop FileSystem or Spark RPC endpoint, used to load + * classes defined by the interpreter when the REPL is used. Allows the user to specify if user + * class path should be first. This class loader delegates getting/finding resources to parent + * loader, which makes sense until REPL never provide resource dynamically. * * Note: [[ClassLoader]] will preferentially load class from parent. Only when parent is null or * the load failed, that it will call the overridden `findClass` function. 
To avoid the potential @@ -60,7 +58,6 @@ class ExecutorClassLoader( private val fetchFn: (String) => InputStream = uri.getScheme() match { case "spark" => getClassFileInputStreamFromSparkRPC - case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer case _ => val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) getClassFileInputStreamFromFileSystem(fileSystem) @@ -113,42 +110,6 @@ class ExecutorClassLoader( } } - private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { - val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { - val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) - val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) - newuri.toURL - } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)) - } - val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(), - SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection] - // Set the connection timeouts (for testing purposes) - if (httpUrlConnectionTimeoutMillis != -1) { - connection.setConnectTimeout(httpUrlConnectionTimeoutMillis) - connection.setReadTimeout(httpUrlConnectionTimeoutMillis) - } - connection.connect() - try { - if (connection.getResponseCode != 200) { - // Close the error stream so that the connection is eligible for re-use - try { - connection.getErrorStream.close() - } catch { - case ioe: IOException => - logError("Exception while closing error stream", ioe) - } - throw new ClassNotFoundException(s"Class file not found at URL $url") - } else { - connection.getInputStream - } - } catch { - case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] => - connection.disconnect() - throw e - } - } - private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)( pathInDirectory: String): InputStream = { val path = new Path(directory, pathInDirectory) diff --git a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala index ec3d790255ad..d49e0fd85229 100644 --- a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala @@ -350,7 +350,7 @@ class SingletonReplSuite extends SparkFunSuite { """ |val timeout = 60000 // 60 seconds |val start = System.currentTimeMillis - |while(sc.getExecutorStorageStatus.size != 3 && + |while(sc.statusTracker.getExecutorInfos.size != 3 && | (System.currentTimeMillis - start) < timeout) { | Thread.sleep(10) |} @@ -361,11 +361,11 @@ class SingletonReplSuite extends SparkFunSuite { |case class Foo(i: Int) |val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_AND_DISK_2) |ret.count() - |val res = sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum + |val res = sc.getRDDStorageInfo.filter(_.id == ret.id).map(_.numCachedPartitions).sum """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) - assertContains("res: Int = 20", output) + assertContains("res: Int = 10", output) } test("should clone and clean line object in ClosureCleaner") { diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 7d35aea8a414..a62f27127346 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../../pom.xml diff --git 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index e5d79d9a9d9d..4086970ffb25 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -29,17 +29,23 @@ private[spark] object Config extends Logging { .stringConf .createWithDefault("default") + val CONTAINER_IMAGE = + ConfigBuilder("spark.kubernetes.container.image") + .doc("Container image to use for Spark containers. Individual container types " + + "(e.g. driver or executor) can also be configured to use different images if desired, " + + "by setting the container type-specific image name.") + .stringConf + .createOptional + val DRIVER_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.driver.container.image") .doc("Container image to use for the driver.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val EXECUTOR_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.executor.container.image") .doc("Container image to use for the executors.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val CONTAINER_IMAGE_PULL_POLICY = ConfigBuilder("spark.kubernetes.container.image.pullPolicy") @@ -48,6 +54,13 @@ private[spark] object Config extends Logging { .checkValues(Set("Always", "Never", "IfNotPresent")) .createWithDefault("IfNotPresent") + val IMAGE_PULL_SECRETS = + ConfigBuilder("spark.kubernetes.container.image.pullSecrets") + .doc("Comma separated list of the Kubernetes secrets used " + + "to access private image registries.") + .stringConf + .createOptional + val KUBERNETES_AUTH_DRIVER_CONF_PREFIX = "spark.kubernetes.authenticate.driver" val KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX = @@ -73,12 +86,24 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_DRIVER_SUBMIT_CHECK = + ConfigBuilder("spark.kubernetes.submitInDriver") + .internal() + .booleanConf + .createOptional + val KUBERNETES_EXECUTOR_LIMIT_CORES = ConfigBuilder("spark.kubernetes.executor.limit.cores") .doc("Specify the hard cpu limit for each executor pod") .stringConf .createOptional + val KUBERNETES_EXECUTOR_REQUEST_CORES = + ConfigBuilder("spark.kubernetes.executor.request.cores") + .doc("Specify the cpu request for each executor pod") + .stringConf + .createOptional + val KUBERNETES_DRIVER_POD_NAME = ConfigBuilder("spark.kubernetes.driver.pod.name") .doc("Name of the driver pod.") @@ -129,74 +154,6 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") - val JARS_DOWNLOAD_LOCATION = - ConfigBuilder("spark.kubernetes.mountDependencies.jarsDownloadDir") - .doc("Location to download jars to in the driver and executors. When using " + - "spark-submit, this directory must be empty and will be mounted as an empty directory " + - "volume on the driver and executor pod.") - .stringConf - .createWithDefault("/var/spark-data/spark-jars") - - val FILES_DOWNLOAD_LOCATION = - ConfigBuilder("spark.kubernetes.mountDependencies.filesDownloadDir") - .doc("Location to download files to in the driver and executors. 
When using " + - "spark-submit, this directory must be empty and will be mounted as an empty directory " + - "volume on the driver and executor pods.") - .stringConf - .createWithDefault("/var/spark-data/spark-files") - - val INIT_CONTAINER_IMAGE = - ConfigBuilder("spark.kubernetes.initContainer.image") - .doc("Image for the driver and executor's init-container for downloading dependencies.") - .stringConf - .createOptional - - val INIT_CONTAINER_MOUNT_TIMEOUT = - ConfigBuilder("spark.kubernetes.mountDependencies.timeout") - .doc("Timeout before aborting the attempt to download and unpack dependencies from remote " + - "locations into the driver and executor pods.") - .timeConf(TimeUnit.SECONDS) - .createWithDefault(300) - - val INIT_CONTAINER_MAX_THREAD_POOL_SIZE = - ConfigBuilder("spark.kubernetes.mountDependencies.maxSimultaneousDownloads") - .doc("Maximum number of remote dependencies to download simultaneously in a driver or " + - "executor pod.") - .intConf - .createWithDefault(5) - - val INIT_CONTAINER_REMOTE_JARS = - ConfigBuilder("spark.kubernetes.initContainer.remoteJars") - .doc("Comma-separated list of jar URIs to download in the init-container. This is " + - "calculated from spark.jars.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_REMOTE_FILES = - ConfigBuilder("spark.kubernetes.initContainer.remoteFiles") - .doc("Comma-separated list of file URIs to download in the init-container. This is " + - "calculated from spark.files.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_CONFIG_MAP_NAME = - ConfigBuilder("spark.kubernetes.initContainer.configMapName") - .doc("Name of the config map to use in the init-container that retrieves submitted files " + - "for the executor.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_CONFIG_MAP_KEY_CONF = - ConfigBuilder("spark.kubernetes.initContainer.configMapKey") - .doc("Key for the entry in the init container config map for submitted files that " + - "corresponds to the properties for this init-container.") - .internal() - .stringConf - .createOptional - val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" @@ -210,5 +167,5 @@ private[spark] object Config extends Logging { val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." - val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." + val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." 
} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 111cb2a3b75e..8da5f24044aa 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -60,26 +60,16 @@ private[spark] object Constants { val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" - val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH" val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" - val ENV_SUBMIT_EXTRA_CLASSPATH = "SPARK_SUBMIT_EXTRA_CLASSPATH" - val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" - val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" - val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" + val ENV_CLASSPATH = "SPARK_CLASSPATH" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" - val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY" - val ENV_MOUNTED_FILES_DIR = "SPARK_MOUNTED_FILES_DIR" - - // Bootstrapping dependencies with the init-container - val INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME = "download-jars-volume" - val INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME = "download-files-volume" - val INIT_CONTAINER_PROPERTIES_FILE_VOLUME = "spark-init-properties" - val INIT_CONTAINER_PROPERTIES_FILE_DIR = "/etc/spark-init" - val INIT_CONTAINER_PROPERTIES_FILE_NAME = "spark-init.properties" - val INIT_CONTAINER_PROPERTIES_FILE_PATH = - s"$INIT_CONTAINER_PROPERTIES_FILE_DIR/$INIT_CONTAINER_PROPERTIES_FILE_NAME" - val INIT_CONTAINER_SECRET_VOLUME_NAME = "spark-init-secret" + val ENV_SPARK_CONF_DIR = "SPARK_CONF_DIR" + // Spark app configs for containers + val SPARK_CONF_VOLUME = "spark-conf-volume" + val SPARK_CONF_DIR_INTERNAL = "/opt/spark/conf" + val SPARK_CONF_FILE_NAME = "spark.properties" + val SPARK_CONF_PATH = s"$SPARK_CONF_DIR_INTERNAL/$SPARK_CONF_FILE_NAME" // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala deleted file mode 100644 index dfeccf9e2bd1..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.deploy.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EmptyDirVolumeSource, EnvVarBuilder, PodBuilder, VolumeMount, VolumeMountBuilder} - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -/** - * Bootstraps an init-container for downloading remote dependencies. This is separated out from - * the init-container steps API because this component can be used to bootstrap init-containers - * for both the driver and executors. - */ -private[spark] class InitContainerBootstrap( - initContainerImage: String, - imagePullPolicy: String, - jarsDownloadPath: String, - filesDownloadPath: String, - configMapName: String, - configMapKey: String, - sparkRole: String, - sparkConf: SparkConf) { - - /** - * Bootstraps an init-container that downloads dependencies to be used by a main container. - */ - def bootstrapInitContainer( - original: PodWithDetachedInitContainer): PodWithDetachedInitContainer = { - val sharedVolumeMounts = Seq[VolumeMount]( - new VolumeMountBuilder() - .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) - .withMountPath(jarsDownloadPath) - .build(), - new VolumeMountBuilder() - .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) - .withMountPath(filesDownloadPath) - .build()) - - val customEnvVarKeyPrefix = sparkRole match { - case SPARK_POD_DRIVER_ROLE => KUBERNETES_DRIVER_ENV_KEY - case SPARK_POD_EXECUTOR_ROLE => "spark.executorEnv." - case _ => throw new SparkException(s"$sparkRole is not a valid Spark pod role") - } - val customEnvVars = sparkConf.getAllWithPrefix(customEnvVarKeyPrefix).toSeq.map { - case (key, value) => - new EnvVarBuilder() - .withName(key) - .withValue(value) - .build() - } - - val initContainer = new ContainerBuilder(original.initContainer) - .withName("spark-init") - .withImage(initContainerImage) - .withImagePullPolicy(imagePullPolicy) - .addAllToEnv(customEnvVars.asJava) - .addNewVolumeMount() - .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) - .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) - .endVolumeMount() - .addToVolumeMounts(sharedVolumeMounts: _*) - .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) - .build() - - val podWithBasicVolumes = new PodBuilder(original.pod) - .editSpec() - .addNewVolume() - .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) - .withNewConfigMap() - .withName(configMapName) - .addNewItem() - .withKey(configMapKey) - .withPath(INIT_CONTAINER_PROPERTIES_FILE_NAME) - .endItem() - .endConfigMap() - .endVolume() - .addNewVolume() - .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) - .withEmptyDir(new EmptyDirVolumeSource()) - .endVolume() - .addNewVolume() - .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) - .withEmptyDir(new EmptyDirVolumeSource()) - .endVolume() - .endSpec() - .build() - - val mainContainer = new ContainerBuilder(original.mainContainer) - .addToVolumeMounts(sharedVolumeMounts: _*) - .addNewEnv() - .withName(ENV_MOUNTED_FILES_DIR) - .withValue(filesDownloadPath) - .endEnv() - .build() - - PodWithDetachedInitContainer( - podWithBasicVolumes, - initContainer, - mainContainer) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala new file mode 100644 index 000000000000..77b634ddfabc --- /dev/null +++ 
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.{JavaMainAppResource, MainAppResource} +import org.apache.spark.internal.config.ConfigEntry + +private[spark] sealed trait KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark driver. + */ +private[spark] case class KubernetesDriverSpecificConf( + mainAppResource: Option[MainAppResource], + mainClass: String, + appName: String, + appArgs: Seq[String]) extends KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark executor. + */ +private[spark] case class KubernetesExecutorSpecificConf( + executorId: String, + driverPod: Pod) + extends KubernetesRoleSpecificConf + +/** + * Structure containing metadata for Kubernetes logic to build Spark pods. 
+ */ +private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( + sparkConf: SparkConf, + roleSpecificConf: T, + appResourceNamePrefix: String, + appId: String, + roleLabels: Map[String, String], + roleAnnotations: Map[String, String], + roleSecretNamesToMountPaths: Map[String, String], + roleEnvs: Map[String, String]) { + + def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) + + def sparkJars(): Seq[String] = sparkConf + .getOption("spark.jars") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def sparkFiles(): Seq[String] = sparkConf + .getOption("spark.files") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + + def imagePullSecrets(): Seq[LocalObjectReference] = { + sparkConf + .get(IMAGE_PULL_SECRETS) + .map(_.split(",")) + .getOrElse(Array.empty[String]) + .map(_.trim) + .map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + } + + def nodeSelector(): Map[String, String] = + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) + + def get[T](config: ConfigEntry[T]): T = sparkConf.get(config) + + def get(conf: String): String = sparkConf.get(conf) + + def get(conf: String, defaultValue: String): String = sparkConf.get(conf, defaultValue) + + def getOption(key: String): Option[String] = sparkConf.getOption(key) +} + +private[spark] object KubernetesConf { + def createDriverConf( + sparkConf: SparkConf, + appName: String, + appResourceNamePrefix: String, + appId: String, + mainAppResource: Option[MainAppResource], + mainClass: String, + appArgs: Array[String]): KubernetesConf[KubernetesDriverSpecificConf] = { + val sparkConfWithMainAppJar = sparkConf.clone() + mainAppResource.foreach { + case JavaMainAppResource(res) => + val previousJars = sparkConf + .getOption("spark.jars") + .map(_.split(",")) + .getOrElse(Array.empty) + if (!previousJars.contains(res)) { + sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) + } + } + + val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) + require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + + s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + + s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + val driverLabels = driverCustomLabels ++ Map( + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) + val driverAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) + val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + + KubernetesConf( + sparkConfWithMainAppJar, + KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs), + appResourceNamePrefix, + appId, + driverLabels, + driverAnnotations, + driverSecretNamesToMountPaths, + driverEnvs) + } + + def createExecutorConf( + sparkConf: SparkConf, + executorId: String, + appId: String, + driverPod: Pod): KubernetesConf[KubernetesExecutorSpecificConf] = { + val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) + 
require( + !executorCustomLabels.contains(SPARK_APP_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") + require( + !executorCustomLabels.contains(SPARK_EXECUTOR_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + + " Spark.") + require( + !executorCustomLabels.contains(SPARK_ROLE_LABEL), + s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") + val executorLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ + executorCustomLabels + val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) + val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnv = sparkConf.getExecutorEnv.toMap + + KubernetesConf( + sparkConf.clone(), + KubernetesExecutorSpecificConf(executorId, driverPod), + sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX), + appId, + executorLabels, + executorAnnotations, + executorSecrets, + executorEnv) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala new file mode 100644 index 000000000000..0c5ae022f407 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.HasMetadata + +private[spark] case class KubernetesDriverSpec( + pod: SparkPod, + driverKubernetesResources: Seq[HasMetadata], + systemProperties: Map[String, String]) + +private[spark] object KubernetesDriverSpec { + def initialSpec(initialProps: Map[String, String]): KubernetesDriverSpec = KubernetesDriverSpec( + SparkPod.initialPod(), + Seq.empty, + initialProps) +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 37331d8bbf9b..ee629068ad90 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,9 +16,7 @@ */ package org.apache.spark.deploy.k8s -import java.io.File - -import io.fabric8.kubernetes.api.model.{Container, Pod, PodBuilder} +import io.fabric8.kubernetes.api.model.LocalObjectReference import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -43,72 +41,23 @@ private[spark] object KubernetesUtils { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } - /** - * Append the given init-container to a pod's list of init-containers. - * - * @param originalPodSpec original specification of the pod - * @param initContainer the init-container to add to the pod - * @return the pod with the init-container added to the list of InitContainers - */ - def appendInitContainer(originalPodSpec: Pod, initContainer: Container): Pod = { - new PodBuilder(originalPodSpec) - .editOrNewSpec() - .addToInitContainers(initContainer) - .endSpec() - .build() - } - /** * For the given collection of file URIs, resolves them as follows: - * - File URIs with scheme file:// are resolved to the given download path. * - File URIs with scheme local:// resolve to just the path of the URI. * - Otherwise, the URIs are returned as-is. */ - def resolveFileUris( - fileUris: Iterable[String], - fileDownloadPath: String): Iterable[String] = { + def resolveFileUrisAndPath(fileUris: Iterable[String]): Iterable[String] = { fileUris.map { uri => - resolveFileUri(uri, fileDownloadPath, false) - } - } - - /** - * If any file uri has any scheme other than local:// it is mapped as if the file - * was downloaded to the file download path. Otherwise, it is mapped to the path - * part of the URI. - */ - def resolveFilePaths(fileUris: Iterable[String], fileDownloadPath: String): Iterable[String] = { - fileUris.map { uri => - resolveFileUri(uri, fileDownloadPath, true) - } - } - - /** - * Get from a given collection of file URIs the ones that represent remote files. 
- */ - def getOnlyRemoteFiles(uris: Iterable[String]): Iterable[String] = { - uris.filter { uri => - val scheme = Utils.resolveURI(uri).getScheme - scheme != "file" && scheme != "local" + resolveFileUri(uri) } } - private def resolveFileUri( - uri: String, - fileDownloadPath: String, - assumesDownloaded: Boolean): String = { + private def resolveFileUri(uri: String): String = { val fileUri = Utils.resolveURI(uri) val fileScheme = Option(fileUri.getScheme).getOrElse("file") fileScheme match { - case "local" => - fileUri.getPath - case _ => - if (assumesDownloaded || fileScheme == "file") { - val fileName = new File(fileUri.getPath).getName - s"$fileDownloadPath/$fileName" - } else { - uri - } + case "local" => fileUri.getPath + case _ => uri } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala deleted file mode 100644 index 8286546ce064..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} - -/** - * Bootstraps a driver or executor container or an init-container with needed secrets mounted. - */ -private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { - - /** - * Mounts Kubernetes secrets as secret volumes into the given container in the given pod. - * - * @param pod the pod into which the secret volumes are being added. - * @param container the container into which the secret volumes are being mounted. - * @return the updated pod and container with the secrets mounted. 
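
(Illustrative note, not part of this patch: this bootstrap is superseded by the
MountSecretsFeatureStep added later in this diff; the user-facing configuration is expected to
stay the same, e.g. a property of the assumed form
  spark.kubernetes.driver.secrets.db-creds=/mnt/secrets
mounts the Kubernetes secret "db-creds" at /mnt/secrets through a volume named "db-creds-volume".)
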
- */ - def mountSecrets(pod: Pod, container: Container): (Pod, Container) = { - var podBuilder = new PodBuilder(pod) - secretNamesToMountPaths.keys.foreach { name => - podBuilder = podBuilder - .editOrNewSpec() - .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() - .endSpec() - } - - var containerBuilder = new ContainerBuilder(container) - secretNamesToMountPaths.foreach { case (name, path) => - containerBuilder = containerBuilder - .addNewVolumeMount() - .withName(secretVolumeName(name)) - .withMountPath(path) - .endVolumeMount() - } - - (podBuilder.build(), containerBuilder.build()) - } - - private def secretVolumeName(secretName: String): String = { - secretName + "-volume" - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala similarity index 67% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala index 0b79f8b12e80..345dd117fd35 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala @@ -16,16 +16,19 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.{Container, Pod} +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} -/** - * Represents a pod with a detached init-container (not yet added to the pod). - * - * @param pod the pod - * @param initContainer the init-container in the pod - * @param mainContainer the main container in the pod - */ -private[spark] case class PodWithDetachedInitContainer( - pod: Pod, - initContainer: Container, - mainContainer: Container) +private[spark] case class SparkPod(pod: Pod, container: Container) + +private[spark] object SparkPod { + def initialPod(): SparkPod = { + SparkPod( + new PodBuilder() + .withNewMetadata() + .endMetadata() + .withNewSpec() + .endSpec() + .build(), + new ContainerBuilder().build()) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala new file mode 100644 index 000000000000..07bdccbe0479 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher + +private[spark] class BasicDriverFeatureStep( + conf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + + private val driverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(s"${conf.appResourceNamePrefix}-driver") + + private val driverContainerImage = conf + .get(DRIVER_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the driver container image")) + + // CPU settings + private val driverCpuCores = conf.get("spark.driver.cores", "1") + private val driverLimitCores = conf.get(KUBERNETES_DRIVER_LIMIT_CORES) + + // Memory settings + private val driverMemoryMiB = conf.get(DRIVER_MEMORY) + private val memoryOverheadMiB = conf + .get(DRIVER_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB + + override def configurePod(pod: SparkPod): SparkPod = { + val driverCustomEnvs = conf.roleEnvs + .toSeq + .map { env => + new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + } + + val driverCpuQuantity = new QuantityBuilder(false) + .withAmount(driverCpuCores) + .build() + val driverMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${driverMemoryWithOverheadMiB}Mi") + .build() + val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => + ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) + } + + val driverContainer = new ContainerBuilder(pod.container) + .withName(DRIVER_CONTAINER_NAME) + .withImage(driverContainerImage) + .withImagePullPolicy(conf.imagePullPolicy()) + .addAllToEnv(driverCustomEnvs.asJava) + .addNewEnv() + .withName(ENV_DRIVER_BIND_ADDRESS) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .endEnv() + .withNewResources() + .addToRequests("cpu", driverCpuQuantity) + .addToLimits(maybeCpuLimitQuantity.toMap.asJava) + .addToRequests("memory", driverMemoryQuantity) + .addToLimits("memory", driverMemoryQuantity) + .endResources() + .addToArgs("driver") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", conf.roleSpecificConf.mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. 
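        // Illustrative aside (not part of this patch): assuming SPARK_CONF_PATH resolves to
        // "/opt/spark/conf/spark.properties" and SparkLauncher.NO_RESOURCE is "spark-internal",
        // a submission with --class org.example.Main and app args "a b" produces container args
        // roughly like:
        //   driver --properties-file /opt/spark/conf/spark.properties \
        //     --class org.example.Main spark-internal a b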
+ .addToArgs(SparkLauncher.NO_RESOURCE) + .addToArgs(conf.roleSpecificConf.appArgs: _*) + .build() + + val driverPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(driverPodName) + .addToLabels(conf.roleLabels.asJava) + .addToAnnotations(conf.roleAnnotations.asJava) + .endMetadata() + .withNewSpec() + .withRestartPolicy("Never") + .withNodeSelector(conf.nodeSelector().asJava) + .addToImagePullSecrets(conf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(driverPod, driverContainer) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val additionalProps = mutable.Map( + KUBERNETES_DRIVER_POD_NAME.key -> driverPodName, + "spark.app.id" -> conf.appId, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> conf.appResourceNamePrefix, + KUBERNETES_DRIVER_SUBMIT_CHECK.key -> "true") + + val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkJars()) + val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkFiles()) + if (resolvedSparkJars.nonEmpty) { + additionalProps.put("spark.jars", resolvedSparkJars.mkString(",")) + } + if (resolvedSparkFiles.nonEmpty) { + additionalProps.put("spark.files", resolvedSparkFiles.mkString(",")) + } + additionalProps.toMap + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala new file mode 100644 index 000000000000..529069d3b8a0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils + +private[spark] class BasicExecutorFeatureStep( + kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]) + extends KubernetesFeatureConfigStep { + + // Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf + private val executorExtraClasspath = kubernetesConf.get(EXECUTOR_CLASS_PATH) + private val executorContainerImage = kubernetesConf + .get(EXECUTOR_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the executor container image")) + private val blockManagerPort = kubernetesConf + .sparkConf + .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) + + private val executorPodNamePrefix = kubernetesConf.appResourceNamePrefix + + private val driverUrl = RpcEndpointAddress( + kubernetesConf.get("spark.driver.host"), + kubernetesConf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + private val executorMemoryMiB = kubernetesConf.get(EXECUTOR_MEMORY) + private val executorMemoryString = kubernetesConf.get( + EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) + + private val memoryOverheadMiB = kubernetesConf + .get(EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + + private val executorCores = kubernetesConf.sparkConf.getInt("spark.executor.cores", 1) + private val executorCoresRequest = + if (kubernetesConf.sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { + kubernetesConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get + } else { + executorCores.toString + } + private val executorLimitCores = kubernetesConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + + override def configurePod(pod: SparkPod): SparkPod = { + val name = s"$executorPodNamePrefix-exec-${kubernetesConf.roleSpecificConf.executorId}" + + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod + // name as the hostname. 
This preserves uniqueness since the end of name contains + // executorId + val hostname = name.substring(Math.max(0, name.length - 63)) + val executorMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryWithOverhead}Mi") + .build() + val executorCpuQuantity = new QuantityBuilder(false) + .withAmount(executorCoresRequest) + .build() + val executorExtraClasspathEnv = executorExtraClasspath.map { cp => + new EnvVarBuilder() + .withName(ENV_CLASSPATH) + .withValue(cp) + .build() + } + val executorExtraJavaOptionsEnv = kubernetesConf + .get(EXECUTOR_JAVA_OPTIONS) + .map { opts => + val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId, + kubernetesConf.roleSpecificConf.executorId) + val delimitedOpts = Utils.splitCommandString(subsOpts) + delimitedOpts.zipWithIndex.map { + case (opt, index) => + new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + } + }.getOrElse(Seq.empty[EnvVar]) + val executorEnv = (Seq( + (ENV_DRIVER_URL, driverUrl), + (ENV_EXECUTOR_CORES, executorCores.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, kubernetesConf.appId), + // This is to set the SPARK_CONF_DIR to be /opt/spark/conf + (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), + (ENV_EXECUTOR_ID, kubernetesConf.roleSpecificConf.executorId)) ++ + kubernetesConf.roleEnvs) + .map(env => new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + ) ++ Seq( + new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .build() + ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + val requiredPorts = Seq( + (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) + .map { case (name, port) => + new ContainerPortBuilder() + .withName(name) + .withContainerPort(port) + .build() + } + + val executorContainer = new ContainerBuilder(pod.container) + .withName("executor") + .withImage(executorContainerImage) + .withImagePullPolicy(kubernetesConf.imagePullPolicy()) + .withNewResources() + .addToRequests("memory", executorMemoryQuantity) + .addToLimits("memory", executorMemoryQuantity) + .addToRequests("cpu", executorCpuQuantity) + .endResources() + .addAllToEnv(executorEnv.asJava) + .withPorts(requiredPorts.asJava) + .addToArgs("executor") + .build() + val containerWithLimitCores = executorLimitCores.map { limitCores => + val executorCpuLimitQuantity = new QuantityBuilder(false) + .withAmount(limitCores) + .build() + new ContainerBuilder(executorContainer) + .editResources() + .addToLimits("cpu", executorCpuLimitQuantity) + .endResources() + .build() + }.getOrElse(executorContainer) + val driverPod = kubernetesConf.roleSpecificConf.driverPod + val executorPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(name) + .withLabels(kubernetesConf.roleLabels.asJava) + .withAnnotations(kubernetesConf.roleAnnotations.asJava) + .withOwnerReferences() + .addNewOwnerReference() + .withController(true) + .withApiVersion(driverPod.getApiVersion) + .withKind(driverPod.getKind) + .withName(driverPod.getMetadata.getName) + .withUid(driverPod.getMetadata.getUid) + .endOwnerReference() + .endMetadata() + .editOrNewSpec() + .withHostname(hostname) + .withRestartPolicy("Never") + .withNodeSelector(kubernetesConf.nodeSelector().asJava) + .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(executorPod, containerWithLimitCores) + } + + override def getAdditionalPodSystemProperties(): 
Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala new file mode 100644 index 000000000000..ff5ad6673b30 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import com.google.common.io.{BaseEncoding, Files} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret, SecretBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +private[spark] class DriverKubernetesCredentialsFeatureStep(kubernetesConf: KubernetesConf[_]) + extends KubernetesFeatureConfigStep { + // TODO clean up this class, and credentials in general. See also SparkKubernetesClientFactory. + // We should use a struct to hold all creds-related fields. A lot of the code is very repetitive. 
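  // Illustrative aside (not part of this patch; key spellings assumed from the config constants):
  // with submitter-local credentials such as
  //   spark.kubernetes.authenticate.driver.clientKeyFile=/path/client.key
  //   spark.kubernetes.authenticate.driver.caCertFile=/path/ca.crt
  // this step base64-encodes the files into a Secret named
  // "<resource-name-prefix>-kubernetes-credentials", mounts it into the driver pod, and rewrites
  // the corresponding spark.kubernetes.authenticate.driver.mounted.* properties to point at the
  // in-pod paths so the driver-side Kubernetes client picks them up.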
+ + private val maybeMountedOAuthTokenFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") + private val maybeMountedClientKeyFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") + private val maybeMountedClientCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") + private val maybeMountedCaCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") + private val driverServiceAccount = kubernetesConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) + + private val oauthTokenBase64 = kubernetesConf + .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") + .map { token => + BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) + } + + private val caCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", + "Driver CA cert file") + private val clientKeyDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", + "Driver client key file") + private val clientCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", + "Driver client cert file") + + // TODO decide whether or not to apply this step entirely in the caller, i.e. the builder. + private val shouldMountSecret = oauthTokenBase64.isDefined || + caCertDataBase64.isDefined || + clientKeyDataBase64.isDefined || + clientCertDataBase64.isDefined + + private val driverCredentialsSecretName = + s"${kubernetesConf.appResourceNamePrefix}-kubernetes-credentials" + + override def configurePod(pod: SparkPod): SparkPod = { + if (!shouldMountSecret) { + pod.copy( + pod = driverServiceAccount.map { account => + new PodBuilder(pod.pod) + .editOrNewSpec() + .withServiceAccount(account) + .withServiceAccountName(account) + .endSpec() + .build() + }.getOrElse(pod.pod)) + } else { + val driverPodWithMountedKubernetesCredentials = + new PodBuilder(pod.pod) + .editOrNewSpec() + .addNewVolume() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withNewSecret().withSecretName(driverCredentialsSecretName).endSecret() + .endVolume() + .endSpec() + .build() + + val driverContainerWithMountedSecretVolume = + new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) + .endVolumeMount() + .build() + SparkPod(driverPodWithMountedKubernetesCredentials, driverContainerWithMountedSecretVolume) + } + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val resolvedMountedOAuthTokenFile = resolveSecretLocation( + maybeMountedOAuthTokenFile, + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) + val resolvedMountedClientKeyFile = resolveSecretLocation( + maybeMountedClientKeyFile, + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_PATH) + val resolvedMountedClientCertFile = resolveSecretLocation( + maybeMountedClientCertFile, + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_PATH) + val resolvedMountedCaCertFile = resolveSecretLocation( + maybeMountedCaCertFile, + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_PATH) + + val redactedTokens = kubernetesConf.sparkConf.getAll + .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)) + .toMap + .mapValues( _ => "") + redactedTokens ++ + resolvedMountedCaCertFile.map { file => + 
Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientKeyFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientCertFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedOAuthTokenFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + if (shouldMountSecret) { + Seq(createCredentialsSecret()) + } else { + Seq.empty + } + } + + private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { + kubernetesConf.getOption(conf) + .map(new File(_)) + .map { file => + require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", + fileType, file.getAbsolutePath)) + BaseEncoding.base64().encode(Files.toByteArray(file)) + } + } + + /** + * Resolve a Kubernetes secret data entry from an optional client credential used by the + * driver to talk to the Kubernetes API server. + * + * @param userSpecifiedCredential the optional user-specified client credential. + * @param secretName name of the Kubernetes secret storing the client credential. + * @return a secret data entry in the form of a map from the secret name to the secret data, + * which may be empty if the user-specified credential is empty. + */ + private def resolveSecretData( + userSpecifiedCredential: Option[String], + secretName: String): Map[String, String] = { + userSpecifiedCredential.map { valueBase64 => + Map(secretName -> valueBase64) + }.getOrElse(Map.empty[String, String]) + } + + private def resolveSecretLocation( + mountedUserSpecified: Option[String], + valueMountedFromSubmitter: Option[String], + mountedCanonicalLocation: String): Option[String] = { + mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => + mountedCanonicalLocation + }) + } + + private def createCredentialsSecret(): Secret = { + val allSecretData = + resolveSecretData( + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ + resolveSecretData( + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ + resolveSecretData( + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ + resolveSecretData( + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) + + new SecretBuilder() + .withNewMetadata() + .withName(driverCredentialsSecretName) + .endMetadata() + .withData(allSecretData.asJava) + .build() + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala new file mode 100644 index 000000000000..f2d7bbd08f30 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{HasMetadata, ServiceBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, SystemClock} + +private[spark] class DriverServiceFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], + clock: Clock = new SystemClock) + extends KubernetesFeatureConfigStep with Logging { + import DriverServiceFeatureStep._ + + require(kubernetesConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, + s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + + "address is managed and set to the driver pod's IP address.") + require(kubernetesConf.getOption(DRIVER_HOST_KEY).isEmpty, + s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + + "managed via a Kubernetes service.") + + private val preferredServiceName = s"${kubernetesConf.appResourceNamePrefix}$DRIVER_SVC_POSTFIX" + private val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { + preferredServiceName + } else { + val randomServiceId = clock.getTimeMillis() + val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" + logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + + s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). 
Falling back to use " + + s"$shorterServiceName as the driver service's name.") + shorterServiceName + } + + private val driverPort = kubernetesConf.sparkConf.getInt( + "spark.driver.port", DEFAULT_DRIVER_PORT) + private val driverBlockManagerPort = kubernetesConf.sparkConf.getInt( + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) + + override def configurePod(pod: SparkPod): SparkPod = pod + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val driverHostname = s"$resolvedServiceName.${kubernetesConf.namespace()}.svc" + Map(DRIVER_HOST_KEY -> driverHostname, + "spark.driver.port" -> driverPort.toString, + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key -> + driverBlockManagerPort.toString) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + val driverService = new ServiceBuilder() + .withNewMetadata() + .withName(resolvedServiceName) + .endMetadata() + .withNewSpec() + .withClusterIP("None") + .withSelector(kubernetesConf.roleLabels.asJava) + .addNewPort() + .withName(DRIVER_PORT_NAME) + .withPort(driverPort) + .withNewTargetPort(driverPort) + .endPort() + .addNewPort() + .withName(BLOCK_MANAGER_PORT_NAME) + .withPort(driverBlockManagerPort) + .withNewTargetPort(driverBlockManagerPort) + .endPort() + .endSpec() + .build() + Seq(driverService) + } +} + +private[spark] object DriverServiceFeatureStep { + val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key + val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key + val DRIVER_SVC_POSTFIX = "-driver-svc" + val MAX_SERVICE_NAME_LENGTH = 63 +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala new file mode 100644 index 000000000000..4c1be3bb1329 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.HasMetadata + +import org.apache.spark.deploy.k8s.SparkPod + +/** + * A collection of functions that together represent a "feature" in pods that are launched for + * Spark drivers and executors. + */ +private[spark] trait KubernetesFeatureConfigStep { + + /** + * Apply modifications on the given pod in accordance to this feature. This can include attaching + * volumes, adding environment variables, and adding labels/annotations. + *

+ * Note that we should return a SparkPod that keeps all of the properties of the passed SparkPod
+ * object. So this is correct:
+ * <pre>
+ * {@code val configuredPod = new PodBuilder(pod.pod)
+ *     .editSpec()
+ *     ...
+ *     .build()
+ *   val configuredContainer = new ContainerBuilder(pod.container)
+ *     ...
+ *     .build()
+ *   SparkPod(configuredPod, configuredContainer)
+ *  }
+ * </pre>
+ * This is incorrect:
+ * <pre>
+ * {@code val configuredPod = new PodBuilder() // Loses the original state
+ *     .editSpec()
+ *     ...
+ *     .build()
+ *   val configuredContainer = new ContainerBuilder() // Loses the original state
+ *     ...
+ *     .build()
+ *   SparkPod(configuredPod, configuredContainer)
+ *  }
+ * </pre>
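 * (Illustrative usage, not part of this patch: a builder applies each feature in sequence,
 * accumulating the pod, extra system properties and extra Kubernetes resources, roughly
 *   var spec = initialSpec
 *   features.foreach { step =>
 *     spec = spec.copy(
 *       pod = step.configurePod(spec.pod),
 *       systemProperties = spec.systemProperties ++ step.getAdditionalPodSystemProperties(),
 *       driverKubernetesResources =
 *         spec.driverKubernetesResources ++ step.getAdditionalKubernetesResources())
 *   }
 * as KubernetesDriverBuilder.buildFromFeatures does later in this diff.)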
      + */ + def configurePod(pod: SparkPod): SparkPod + + /** + * Return any system properties that should be set on the JVM in accordance to this feature. + */ + def getAdditionalPodSystemProperties(): Map[String, String] + + /** + * Return any additional Kubernetes resources that should be added to support this feature. Only + * applicable when creating the driver in cluster mode. + */ + def getAdditionalKubernetesResources(): Seq[HasMetadata] +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala new file mode 100644 index 000000000000..97fa9499b2ed --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class MountSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedVolumes = kubernetesConf + .roleSecretNamesToMountPaths + .keys + .map(secretName => + new VolumeBuilder() + .withName(secretVolumeName(secretName)) + .withNewSecret() + .withSecretName(secretName) + .endSecret() + .build()) + val podWithVolumes = new PodBuilder(pod.pod) + .editOrNewSpec() + .addToVolumes(addedVolumes.toSeq: _*) + .endSpec() + .build() + val addedVolumeMounts = kubernetesConf + .roleSecretNamesToMountPaths + .map { + case (secretName, mountPath) => + new VolumeMountBuilder() + .withName(secretVolumeName(secretName)) + .withMountPath(mountPath) + .build() + } + val containerWithMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(addedVolumeMounts.toSeq: _*) + .build() + SparkPod(podWithVolumes, containerWithMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def secretVolumeName(secretName: String): String = s"$secretName-volume" +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala deleted file mode 100644 index 00c9c4ee4917..000000000000 --- 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import java.util.UUID - -import com.google.common.primitives.Longs - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps._ -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.InitContainerConfigOrchestrator -import org.apache.spark.launcher.SparkLauncher -import org.apache.spark.util.SystemClock -import org.apache.spark.util.Utils - -/** - * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to - * configure the Spark driver pod. The returned steps will be applied one by one in the given - * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication - * to construct and create the driver pod. It uses the InitContainerConfigOrchestrator to - * configure the driver init-container if one is needed, i.e., when there are remote dependencies - * to localize. - */ -private[spark] class DriverConfigOrchestrator( - kubernetesAppId: String, - launchTime: Long, - mainAppResource: Option[MainAppResource], - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) { - - // The resource name prefix is derived from the Spark application name, making it easy to connect - // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the - // application the user submitted. 
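
(Aside, not part of this patch: the behavior removed here is relocated rather than dropped — the
resource name prefix is now computed in KubernetesClientApplication as roughly
  s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-")
and the selection of configuration steps moves into KubernetesDriverBuilder.buildFromFeatures.)
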
- private val kubernetesResourceNamePrefix = { - val uuid = UUID.nameUUIDFromBytes(Longs.toByteArray(launchTime)).toString.replaceAll("-", "") - s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-") - } - - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val initContainerConfigMapName = s"$kubernetesResourceNamePrefix-init-config" - private val jarsDownloadPath = sparkConf.get(JARS_DOWNLOAD_LOCATION) - private val filesDownloadPath = sparkConf.get(FILES_DOWNLOAD_LOCATION) - - def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { - val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_LABEL_PREFIX) - require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + - s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + - s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - - val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_SECRETS_PREFIX) - - val allDriverLabels = driverCustomLabels ++ Map( - SPARK_APP_ID_LABEL -> kubernetesAppId, - SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) - - val initialSubmissionStep = new BasicDriverConfigurationStep( - kubernetesAppId, - kubernetesResourceNamePrefix, - allDriverLabels, - imagePullPolicy, - appName, - mainClass, - appArgs, - sparkConf) - - val serviceBootstrapStep = new DriverServiceBootstrapStep( - kubernetesResourceNamePrefix, - allDriverLabels, - sparkConf, - new SystemClock) - - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - sparkConf, kubernetesResourceNamePrefix) - - val additionalMainAppJar = if (mainAppResource.nonEmpty) { - val mayBeResource = mainAppResource.get match { - case JavaMainAppResource(resource) if resource != SparkLauncher.NO_RESOURCE => - Some(resource) - case _ => None - } - mayBeResource - } else { - None - } - - val sparkJars = sparkConf.getOption("spark.jars") - .map(_.split(",")) - .getOrElse(Array.empty[String]) ++ - additionalMainAppJar.toSeq - val sparkFiles = sparkConf.getOption("spark.files") - .map(_.split(",")) - .getOrElse(Array.empty[String]) - - val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { - Seq(new DependencyResolutionStep( - sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath)) - } else { - Nil - } - - val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { - val orchestrator = new InitContainerConfigOrchestrator( - sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath, - imagePullPolicy, - initContainerConfigMapName, - INIT_CONTAINER_PROPERTIES_FILE_NAME, - sparkConf) - val bootstrapStep = new DriverInitContainerBootstrapStep( - orchestrator.getAllConfigurationSteps, - initContainerConfigMapName, - INIT_CONTAINER_PROPERTIES_FILE_NAME) - - Seq(bootstrapStep) - } else { - Nil - } - - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - - Seq( - initialSubmissionStep, - serviceBootstrapStep, - kubernetesCredentialsStep) ++ - dependencyResolutionStep ++ - initContainerBootstrapStep ++ - mountSecretsStep - } - - private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { - files.exists { uri => - Utils.resolveURI(uri).getScheme != "local" - } - } 
-} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 5884348cb3e4..a97f5650fb86 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -16,21 +16,20 @@ */ package org.apache.spark.deploy.k8s.submit +import java.io.StringWriter import java.util.{Collections, UUID} - -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.util.control.NonFatal +import java.util.Properties import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -42,9 +41,9 @@ import org.apache.spark.util.Utils * @param driverArgs arguments to the driver */ private[spark] case class ClientArguments( - mainAppResource: Option[MainAppResource], - mainClass: String, - driverArgs: Array[String]) + mainAppResource: Option[MainAppResource], + mainClass: String, + driverArgs: Array[String]) private[spark] object ClientArguments { @@ -79,8 +78,9 @@ private[spark] object ClientArguments { * watcher that monitors and logs the application status. Waits for the application to terminate if * spark.kubernetes.submission.waitAppCompletion is true. * - * @param submissionSteps steps that collectively configure the driver - * @param sparkConf the submission client Spark configuration + * @param builder Responsible for building the base driver pod based on a composition of + * implemented features. + * @param kubernetesConf application configuration * @param kubernetesClient the client to talk to the Kubernetes API server * @param waitForAppCompletion a flag indicating whether the client should wait for the application * to complete @@ -88,55 +88,41 @@ private[spark] object ClientArguments { * @param watcher a watcher that monitors and logs the application status */ private[spark] class Client( - submissionSteps: Seq[DriverConfigurationStep], - sparkConf: SparkConf, + builder: KubernetesDriverBuilder, + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, - watcher: LoggingPodStatusWatcher) extends Logging { - - private val driverJavaOptions = sparkConf.get( - org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) + watcher: LoggingPodStatusWatcher, + kubernetesResourceNamePrefix: String) extends Logging { - /** - * Run command that initializes a DriverSpec that will be updated after each - * DriverConfigurationStep in the sequence that is passed in. 
The final KubernetesDriverSpec - * will be used to build the Driver Container, Driver Pod, and Kubernetes Resources - */ def run(): Unit = { - var currentDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf) - // submissionSteps contain steps necessary to take, to resolve varying - // client arguments that are passed in, created by orchestrator - for (nextStep <- submissionSteps) { - currentDriverSpec = nextStep.configureDriver(currentDriverSpec) - } - - val resolvedDriverJavaOpts = currentDriverSpec - .driverSparkConf - // Remove this as the options are instead extracted and set individually below using - // environment variables with prefix SPARK_JAVA_OPT_. - .remove(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) - .getAll - .map { - case (confKey, confValue) => s"-D$confKey=$confValue" - } ++ driverJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) - val driverJavaOptsEnvs: Seq[EnvVar] = resolvedDriverJavaOpts.zipWithIndex.map { - case (option, index) => - new EnvVarBuilder() - .withName(s"$ENV_JAVA_OPT_PREFIX$index") - .withValue(option) - .build() - } - - val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer) - .addAllToEnv(driverJavaOptsEnvs.asJava) + val resolvedDriverSpec = builder.buildFromFeatures(kubernetesConf) + val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map" + val configMap = buildConfigMap(configMapName, resolvedDriverSpec.systemProperties) + // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the + // Spark command builder to pickup on the Java Options present in the ConfigMap + val resolvedDriverContainer = new ContainerBuilder(resolvedDriverSpec.pod.container) + .addNewEnv() + .withName(ENV_SPARK_CONF_DIR) + .withValue(SPARK_CONF_DIR_INTERNAL) + .endEnv() + .addNewVolumeMount() + .withName(SPARK_CONF_VOLUME) + .withMountPath(SPARK_CONF_DIR_INTERNAL) + .endVolumeMount() .build() - val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod) + val resolvedDriverPod = new PodBuilder(resolvedDriverSpec.pod.pod) .editSpec() .addToContainers(resolvedDriverContainer) + .addNewVolume() + .withName(SPARK_CONF_VOLUME) + .withNewConfigMap() + .withName(configMapName) + .endConfigMap() + .endVolume() .endSpec() .build() - Utils.tryWithResource( kubernetesClient .pods() @@ -144,11 +130,10 @@ private[spark] class Client( .watch(watcher)) { _ => val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { - if (currentDriverSpec.otherKubernetesResources.nonEmpty) { - val otherKubernetesResources = currentDriverSpec.otherKubernetesResources - addDriverOwnerReference(createdDriverPod, otherKubernetesResources) - kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() - } + val otherKubernetesResources = + resolvedDriverSpec.driverKubernetesResources ++ Seq(configMap) + addDriverOwnerReference(createdDriverPod, otherKubernetesResources) + kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() } catch { case NonFatal(e) => kubernetesClient.pods().delete(createdDriverPod) @@ -180,6 +165,23 @@ private[spark] class Client( originalMetadata.setOwnerReferences(Collections.singletonList(driverPodOwnerReference)) } } + + // Build a Config Map that will house spark conf properties in a single file for spark-submit + private def buildConfigMap(configMapName: String, conf: Map[String, String]): ConfigMap = { + val properties = new Properties() + conf.foreach { case (k, v) => + properties.setProperty(k, v) + } + val propertiesWriter = new 
StringWriter() + properties.store(propertiesWriter, + s"Java properties built from Kubernetes config map with name: $configMapName") + new ConfigMapBuilder() + .withNewMetadata() + .withName(configMapName) + .endMetadata() + .addToData(SPARK_CONF_FILE_NAME, propertiesWriter.toString) + .build() + } } /** @@ -193,7 +195,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { } private def run(clientArguments: ClientArguments, sparkConf: SparkConf): Unit = { - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) + val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") // For constructing the app ID, we can't use the Spark application name, as the app ID is going // to be added as a label to group resources belonging to the same application. Label values are // considerably restrictive, e.g. must be no longer than 63 characters in length. So we generate @@ -201,7 +203,19 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val kubernetesAppId = s"spark-${UUID.randomUUID().toString.replaceAll("-", "")}" val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) - val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") + val kubernetesResourceNamePrefix = { + s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") + } + val kubernetesConf = KubernetesConf.createDriverConf( + sparkConf, + appName, + kubernetesResourceNamePrefix, + kubernetesAppId, + clientArguments.mainAppResource, + clientArguments.mainClass, + clientArguments.driverArgs) + val builder = new KubernetesDriverBuilder + val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. val master = sparkConf.get("spark.master").substring("k8s://".length) @@ -209,15 +223,6 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) - val orchestrator = new DriverConfigOrchestrator( - kubernetesAppId, - launchTime, - clientArguments.mainAppResource, - appName, - clientArguments.mainClass, - clientArguments.driverArgs, - sparkConf) - Utils.tryWithResource(SparkKubernetesClientFactory.createKubernetesClient( master, Some(namespace), @@ -226,12 +231,13 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None, None)) { kubernetesClient => val client = new Client( - orchestrator.getAllConfigurationSteps, - sparkConf, + builder, + kubernetesConf, kubernetesClient, waitForAppCompletion, appName, - watcher) + watcher, + kubernetesResourceNamePrefix) client.run() } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala new file mode 100644 index 000000000000..c7579ed8cb68 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep} + +private[spark] class KubernetesDriverBuilder( + provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = + new BasicDriverFeatureStep(_), + provideCredentialsStep: (KubernetesConf[KubernetesDriverSpecificConf]) + => DriverKubernetesCredentialsFeatureStep = + new DriverKubernetesCredentialsFeatureStep(_), + provideServiceStep: (KubernetesConf[KubernetesDriverSpecificConf]) => DriverServiceFeatureStep = + new DriverServiceFeatureStep(_), + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountSecretsFeatureStep) = + new MountSecretsFeatureStep(_)) { + + def buildFromFeatures( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { + val baseFeatures = Seq( + provideBasicStep(kubernetesConf), + provideCredentialsStep(kubernetesConf), + provideServiceStep(kubernetesConf)) + val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) + } else baseFeatures + + var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) + for (feature <- allFeatures) { + val configuredPod = feature.configurePod(spec.pod) + val addedSystemProperties = feature.getAdditionalPodSystemProperties() + val addedResources = feature.getAdditionalKubernetesResources() + spec = KubernetesDriverSpec( + configuredPod, + spec.driverKubernetesResources ++ addedResources, + spec.systemProperties ++ addedSystemProperties) + } + spec + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala deleted file mode 100644 index db13f09387ef..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, HasMetadata, Pod, PodBuilder} - -import org.apache.spark.SparkConf - -/** - * Represents the components and characteristics of a Spark driver. The driver can be considered - * as being comprised of the driver pod itself, any other Kubernetes resources that the driver - * pod depends on, and the SparkConf that should be supplied to the Spark application. The driver - * container should be operated on via the specific field of this case class as opposed to trying - * to edit the container directly on the pod. The driver container should be attached at the - * end of executing all submission steps. - */ -private[spark] case class KubernetesDriverSpec( - driverPod: Pod, - driverContainer: Container, - otherKubernetesResources: Seq[HasMetadata], - driverSparkConf: SparkConf) - -private[spark] object KubernetesDriverSpec { - def initialSpec(initialSparkConf: SparkConf): KubernetesDriverSpec = { - KubernetesDriverSpec( - // Set new metadata and a new spec so that submission steps can use - // PodBuilder#editMetadata() and/or PodBuilder#editSpec() safely. - new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), - new ContainerBuilder().build(), - Seq.empty[HasMetadata], - initialSparkConf.clone()) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala deleted file mode 100644 index b7a69a7dfd47..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
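The spec deleted above carried the pod, the driver container, the extra resources, and a mutable copy of the SparkConf. Judging from how buildFromFeatures constructs its result, the replacement KubernetesDriverSpec folds the pod and container into a SparkPod and tracks plain system properties instead of a SparkConf; an approximate sketch of that shape (the real definition lives elsewhere in this patch):

  import io.fabric8.kubernetes.api.model.HasMetadata
  import org.apache.spark.deploy.k8s.SparkPod

  // Approximate shape of the replacement spec, inferred from its construction in the builder.
  private[spark] case class DriverSpecSketch(
    pod: SparkPod,                               // driver pod plus its driver container
    driverKubernetesResources: Seq[HasMetadata], // resources created alongside the pod
    systemProperties: Map[String, String])       // driver Spark properties accumulated by the steps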
- */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, PodBuilder, QuantityBuilder} - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} - -/** - * Performs basic configuration for the driver pod. - */ -private[spark] class BasicDriverConfigurationStep( - kubernetesAppId: String, - resourceNamePrefix: String, - driverLabels: Map[String, String], - imagePullPolicy: String, - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) extends DriverConfigurationStep { - - private val driverPodName = sparkConf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(s"$resourceNamePrefix-driver") - - private val driverExtraClasspath = sparkConf.get(DRIVER_CLASS_PATH) - - private val driverContainerImage = sparkConf - .get(DRIVER_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the driver container image")) - - // CPU settings - private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") - private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) - - // Memory settings - private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) - private val driverMemoryString = sparkConf.get( - DRIVER_MEMORY.key, DRIVER_MEMORY.defaultValueString) - private val memoryOverheadMiB = sparkConf - .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) - private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => - new EnvVarBuilder() - .withName(ENV_SUBMIT_EXTRA_CLASSPATH) - .withValue(classPath) - .build() - } - - val driverCustomAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) - require(!driverCustomAnnotations.contains(SPARK_APP_NAME_ANNOTATION), - s"Annotation with key $SPARK_APP_NAME_ANNOTATION is not allowed as it is reserved for" + - " Spark bookkeeping operations.") - - val driverCustomEnvs = sparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq - .map { env => - new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - } - - val driverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) - - val nodeSelector = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) - - val driverCpuQuantity = new QuantityBuilder(false) - .withAmount(driverCpuCores) - .build() - val driverMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryMiB}Mi") - .build() - val driverMemoryLimitQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryWithOverheadMiB}Mi") - .build() - val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => - ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) - } - - val driverContainer = new ContainerBuilder(driverSpec.driverContainer) - .withName(DRIVER_CONTAINER_NAME) - .withImage(driverContainerImage) - 
.withImagePullPolicy(imagePullPolicy) - .addAllToEnv(driverCustomEnvs.asJava) - .addToEnv(driverExtraClasspathEnv.toSeq: _*) - .addNewEnv() - .withName(ENV_DRIVER_MEMORY) - .withValue(driverMemoryString) - .endEnv() - .addNewEnv() - .withName(ENV_DRIVER_MAIN_CLASS) - .withValue(mainClass) - .endEnv() - .addNewEnv() - .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.map(arg => "\"" + arg + "\"").mkString(" ")) - .endEnv() - .addNewEnv() - .withName(ENV_DRIVER_BIND_ADDRESS) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .endEnv() - .withNewResources() - .addToRequests("cpu", driverCpuQuantity) - .addToRequests("memory", driverMemoryQuantity) - .addToLimits("memory", driverMemoryLimitQuantity) - .addToLimits(maybeCpuLimitQuantity.toMap.asJava) - .endResources() - .build() - - val baseDriverPod = new PodBuilder(driverSpec.driverPod) - .editOrNewMetadata() - .withName(driverPodName) - .addToLabels(driverLabels.asJava) - .addToAnnotations(driverAnnotations.asJava) - .endMetadata() - .withNewSpec() - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .endSpec() - .build() - - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set("spark.app.id", kubernetesAppId) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) - - driverSpec.copy( - driverPod = baseDriverPod, - driverSparkConf = resolvedSparkConf, - driverContainer = driverContainer) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala deleted file mode 100644 index d4b83235b4e3..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import io.fabric8.kubernetes.api.model.ContainerBuilder - -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Step that configures the classpath, spark.jars, and spark.files for the driver given that the - * user may provide remote files or files with local:// schemes. 
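To make the deleted sizing logic above concrete: with spark.driver.memory=2g, and assuming the conventional 0.10 overhead factor with a 384 MiB floor (values not shown in this patch), the container requests 2048Mi and is limited to 2432Mi.

  // Worked example of the deleted driver memory sizing; the 0.10 factor and
  // 384 MiB floor are assumed defaults, not read from this patch.
  val driverMemoryMiB = 2048L
  val memoryOverheadMiB = math.max((0.10 * driverMemoryMiB).toLong, 384L) // 384
  val memoryRequest = s"${driverMemoryMiB}Mi"                             // "2048Mi"
  val memoryLimit = s"${driverMemoryMiB + memoryOverheadMiB}Mi"           // "2432Mi"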
- */ -private[spark] class DependencyResolutionStep( - sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesUtils.resolveFileUris(sparkJars, jarsDownloadPath) - val resolvedSparkFiles = KubernetesUtils.resolveFileUris(sparkFiles, filesDownloadPath) - - val sparkConf = driverSpec.driverSparkConf.clone() - if (resolvedSparkJars.nonEmpty) { - sparkConf.set("spark.jars", resolvedSparkJars.mkString(",")) - } - if (resolvedSparkFiles.nonEmpty) { - sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) - } - - val resolvedClasspath = KubernetesUtils.resolveFilePaths(sparkJars, jarsDownloadPath) - val resolvedDriverContainer = if (resolvedClasspath.nonEmpty) { - new ContainerBuilder(driverSpec.driverContainer) - .addNewEnv() - .withName(ENV_MOUNTED_CLASSPATH) - .withValue(resolvedClasspath.mkString(File.pathSeparator)) - .endEnv() - .build() - } else { - driverSpec.driverContainer - } - - driverSpec.copy( - driverContainer = resolvedDriverContainer, - driverSparkConf = sparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala deleted file mode 100644 index 9fb3dafdda54..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.StringWriter -import java.util.Properties - -import io.fabric8.kubernetes.api.model.{ConfigMap, ConfigMapBuilder, ContainerBuilder, HasMetadata} - -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} - -/** - * Configures the driver init-container that localizes remote dependencies into the driver pod. - * It applies the given InitContainerConfigurationSteps in the given order to produce a final - * InitContainerSpec that is then used to configure the driver pod with the init-container attached. - * It also builds a ConfigMap that will be mounted into the init-container. The ConfigMap carries - * configuration properties for the init-container. 
- */ -private[spark] class DriverInitContainerBootstrapStep( - steps: Seq[InitContainerConfigurationStep], - configMapName: String, - configMapKey: String) - extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - var initContainerSpec = InitContainerSpec( - properties = Map.empty[String, String], - driverSparkConf = Map.empty[String, String], - initContainer = new ContainerBuilder().build(), - driverContainer = driverSpec.driverContainer, - driverPod = driverSpec.driverPod, - dependentResources = Seq.empty[HasMetadata]) - for (nextStep <- steps) { - initContainerSpec = nextStep.configureInitContainer(initContainerSpec) - } - - val configMap = buildConfigMap( - configMapName, - configMapKey, - initContainerSpec.properties) - val resolvedDriverSparkConf = driverSpec.driverSparkConf - .clone() - .set(INIT_CONTAINER_CONFIG_MAP_NAME, configMapName) - .set(INIT_CONTAINER_CONFIG_MAP_KEY_CONF, configMapKey) - .setAll(initContainerSpec.driverSparkConf) - val resolvedDriverPod = KubernetesUtils.appendInitContainer( - initContainerSpec.driverPod, initContainerSpec.initContainer) - - driverSpec.copy( - driverPod = resolvedDriverPod, - driverContainer = initContainerSpec.driverContainer, - driverSparkConf = resolvedDriverSparkConf, - otherKubernetesResources = - driverSpec.otherKubernetesResources ++ - initContainerSpec.dependentResources ++ - Seq(configMap)) - } - - private def buildConfigMap( - configMapName: String, - configMapKey: String, - config: Map[String, String]): ConfigMap = { - val properties = new Properties() - config.foreach { entry => - properties.setProperty(entry._1, entry._2) - } - val propertiesWriter = new StringWriter() - properties.store(propertiesWriter, - s"Java properties built from Kubernetes config map with name: $configMapName " + - s"and config map key: $configMapKey") - new ConfigMapBuilder() - .withNewMetadata() - .withName(configMapName) - .endMetadata() - .addToData(configMapKey, propertiesWriter.toString) - .build() - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala deleted file mode 100644 index ccc18908658f..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File -import java.nio.charset.StandardCharsets - -import scala.collection.JavaConverters._ -import scala.language.implicitConversions - -import com.google.common.io.{BaseEncoding, Files} -import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder, Secret, SecretBuilder} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Mounts Kubernetes credentials into the driver pod. The driver will use such mounted credentials - * to request executors. - */ -private[spark] class DriverKubernetesCredentialsStep( - submissionSparkConf: SparkConf, - kubernetesResourceNamePrefix: String) extends DriverConfigurationStep { - - private val maybeMountedOAuthTokenFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") - private val maybeMountedClientKeyFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") - private val maybeMountedClientCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") - private val maybeMountedCaCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") - private val driverServiceAccount = submissionSparkConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverSparkConf = driverSpec.driverSparkConf.clone() - - val oauthTokenBase64 = submissionSparkConf - .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") - .map { token => - BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) - } - val caCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - "Driver CA cert file") - val clientKeyDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - "Driver client key file") - val clientCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - "Driver client cert file") - - val driverSparkConfWithCredentialsLocations = setDriverPodKubernetesCredentialLocations( - driverSparkConf, - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val kubernetesCredentialsSecret = createCredentialsSecret( - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val driverPodWithMountedKubernetesCredentials = kubernetesCredentialsSecret.map { secret => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .addNewVolume() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - .withNewSecret().withSecretName(secret.getMetadata.getName).endSecret() - .endVolume() - .endSpec() - .build() - }.getOrElse( - driverServiceAccount.map { account => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .withServiceAccount(account) - .withServiceAccountName(account) - .endSpec() - .build() - }.getOrElse(driverSpec.driverPod) - ) - - val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { secret => - new ContainerBuilder(driverSpec.driverContainer) - .addNewVolumeMount() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - 
.withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) - .endVolumeMount() - .build() - }.getOrElse(driverSpec.driverContainer) - - driverSpec.copy( - driverPod = driverPodWithMountedKubernetesCredentials, - otherKubernetesResources = - driverSpec.otherKubernetesResources ++ kubernetesCredentialsSecret.toSeq, - driverSparkConf = driverSparkConfWithCredentialsLocations, - driverContainer = driverContainerWithMountedSecretVolume) - } - - private def createCredentialsSecret( - driverOAuthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): Option[Secret] = { - val allSecretData = - resolveSecretData( - driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ - resolveSecretData( - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ - resolveSecretData( - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ - resolveSecretData( - driverOAuthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) - - if (allSecretData.isEmpty) { - None - } else { - Some(new SecretBuilder() - .withNewMetadata() - .withName(s"$kubernetesResourceNamePrefix-kubernetes-credentials") - .endMetadata() - .withData(allSecretData.asJava) - .build()) - } - } - - private def setDriverPodKubernetesCredentialLocations( - driverSparkConf: SparkConf, - driverOauthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): SparkConf = { - val resolvedMountedOAuthTokenFile = resolveSecretLocation( - maybeMountedOAuthTokenFile, - driverOauthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) - val resolvedMountedClientKeyFile = resolveSecretLocation( - maybeMountedClientKeyFile, - driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_PATH) - val resolvedMountedClientCertFile = resolveSecretLocation( - maybeMountedClientCertFile, - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_PATH) - val resolvedMountedCaCertFile = resolveSecretLocation( - maybeMountedCaCertFile, - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_PATH) - - val sparkConfWithCredentialLocations = driverSparkConf - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - resolvedMountedCaCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - resolvedMountedClientKeyFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - resolvedMountedClientCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX", - resolvedMountedOAuthTokenFile) - - // Redact all OAuth token values - sparkConfWithCredentialLocations - .getAll - .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)).map(_._1) - .foreach { - sparkConfWithCredentialLocations.set(_, "") - } - sparkConfWithCredentialLocations - } - - private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { - submissionSparkConf.getOption(conf) - .map(new File(_)) - .map { file => - require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", - fileType, file.getAbsolutePath)) - BaseEncoding.base64().encode(Files.toByteArray(file)) - } - } - - private def resolveSecretLocation( - mountedUserSpecified: Option[String], - valueMountedFromSubmitter: Option[String], - mountedCanonicalLocation: String): Option[String] = { - 
mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => - mountedCanonicalLocation - }) - } - - /** - * Resolve a Kubernetes secret data entry from an optional client credential used by the - * driver to talk to the Kubernetes API server. - * - * @param userSpecifiedCredential the optional user-specified client credential. - * @param secretName name of the Kubernetes secret storing the client credential. - * @return a secret data entry in the form of a map from the secret name to the secret data, - * which may be empty if the user-specified credential is empty. - */ - private def resolveSecretData( - userSpecifiedCredential: Option[String], - secretName: String): Map[String, String] = { - userSpecifiedCredential.map { valueBase64 => - Map(secretName -> valueBase64) - }.getOrElse(Map.empty[String, String]) - } - - private implicit def augmentSparkConf(sparkConf: SparkConf): OptionSettableSparkConf = { - new OptionSettableSparkConf(sparkConf) - } -} - -private class OptionSettableSparkConf(sparkConf: SparkConf) { - def setOption(configEntry: String, option: Option[String]): SparkConf = { - option.foreach { opt => - sparkConf.set(configEntry, opt) - } - sparkConf - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala deleted file mode 100644 index eb594e4f16ec..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.ServiceBuilder - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.Logging -import org.apache.spark.util.Clock - -/** - * Allows the driver to be reachable by executor pods through a headless service. The service's - * ports should correspond to the ports that the executor will reach the pod at for RPC. 
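The small OptionSettableSparkConf helper above exists so the credentials step can write a key only when a value was actually resolved. A hypothetical call, with illustrative key names:

  // Illustrative only; the keys shown are examples, and setOption relies on the
  // implicit augmentSparkConf conversion defined in the (now deleted) step.
  import org.apache.spark.SparkConf
  val conf = new SparkConf(false)
  conf.setOption("spark.kubernetes.authenticate.driver.mounted.caCertFile", Some("/mnt/secrets/ca.crt"))
  conf.setOption("spark.kubernetes.authenticate.driver.mounted.oauthTokenFile", None) // no-op, conf untouched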
- */ -private[spark] class DriverServiceBootstrapStep( - resourceNamePrefix: String, - driverLabels: Map[String, String], - sparkConf: SparkConf, - clock: Clock) extends DriverConfigurationStep with Logging { - - import DriverServiceBootstrapStep._ - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - require(sparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, - s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + - "address is managed and set to the driver pod's IP address.") - require(sparkConf.getOption(DRIVER_HOST_KEY).isEmpty, - s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + - "managed via a Kubernetes service.") - - val preferredServiceName = s"$resourceNamePrefix$DRIVER_SVC_POSTFIX" - val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { - preferredServiceName - } else { - val randomServiceId = clock.getTimeMillis() - val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" - logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + - s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " + - s"$shorterServiceName as the driver service's name.") - shorterServiceName - } - - val driverPort = sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) - val driverBlockManagerPort = sparkConf.getInt( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) - val driverService = new ServiceBuilder() - .withNewMetadata() - .withName(resolvedServiceName) - .endMetadata() - .withNewSpec() - .withClusterIP("None") - .withSelector(driverLabels.asJava) - .addNewPort() - .withName(DRIVER_PORT_NAME) - .withPort(driverPort) - .withNewTargetPort(driverPort) - .endPort() - .addNewPort() - .withName(BLOCK_MANAGER_PORT_NAME) - .withPort(driverBlockManagerPort) - .withNewTargetPort(driverBlockManagerPort) - .endPort() - .endSpec() - .build() - - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) - val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc.cluster.local" - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .set(DRIVER_HOST_KEY, driverHostname) - .set("spark.driver.port", driverPort.toString) - .set( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, driverBlockManagerPort) - - driverSpec.copy( - driverSparkConf = resolvedSparkConf, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(driverService)) - } -} - -private[spark] object DriverServiceBootstrapStep { - val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key - val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key - val DRIVER_SVC_POSTFIX = "-driver-svc" - val MAX_SERVICE_NAME_LENGTH = 63 -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala deleted file mode 100644 index 01469853dacc..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
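To make the 63-character service-name limit above concrete, here is the fallback with a made-up prefix and timestamp; executors then resolve the driver through the in-cluster DNS name built from the service and namespace.

  // Hypothetical values; only the length check mirrors the deleted logic above.
  val resourceNamePrefix = "word-count-with-a-very-long-application-name-1522795000000"
  val preferred = s"$resourceNamePrefix-driver-svc"     // 69 chars: over the 63-char limit
  val fallback = s"spark-${1522795001234L}-driver-svc"  // "spark-1522795001234-driver-svc", 30 chars
  // The driver hostname then becomes s"$serviceName.$namespace.svc.cluster.local"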
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.KubernetesUtils - -/** - * Performs basic configuration for the driver init-container with most of the work delegated to - * the given InitContainerBootstrap. - */ -private[spark] class BasicInitContainerConfigurationStep( - sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String, - bootstrap: InitContainerBootstrap) - extends InitContainerConfigurationStep { - - override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec = { - val remoteJarsToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkJars) - val remoteFilesToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkFiles) - val remoteJarsConf = if (remoteJarsToDownload.nonEmpty) { - Map(INIT_CONTAINER_REMOTE_JARS.key -> remoteJarsToDownload.mkString(",")) - } else { - Map() - } - val remoteFilesConf = if (remoteFilesToDownload.nonEmpty) { - Map(INIT_CONTAINER_REMOTE_FILES.key -> remoteFilesToDownload.mkString(",")) - } else { - Map() - } - - val baseInitContainerConfig = Map( - JARS_DOWNLOAD_LOCATION.key -> jarsDownloadPath, - FILES_DOWNLOAD_LOCATION.key -> filesDownloadPath) ++ - remoteJarsConf ++ - remoteFilesConf - - val bootstrapped = bootstrap.bootstrapInitContainer( - PodWithDetachedInitContainer( - spec.driverPod, - spec.initContainer, - spec.driverContainer)) - - spec.copy( - initContainer = bootstrapped.initContainer, - driverContainer = bootstrapped.mainContainer, - driverPod = bootstrapped.pod, - properties = spec.properties ++ baseInitContainerConfig) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala deleted file mode 100644 index f2c29c7ce107..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -/** - * Figures out and returns the complete ordered list of InitContainerConfigurationSteps required to - * configure the driver init-container. The returned steps will be applied in the given order to - * produce a final InitContainerSpec that is used to construct the driver init-container in - * DriverInitContainerBootstrapStep. This class is only used when an init-container is needed, i.e., - * when there are remote application dependencies to localize. - */ -private[spark] class InitContainerConfigOrchestrator( - sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String, - imagePullPolicy: String, - configMapName: String, - configMapKey: String, - sparkConf: SparkConf) { - - private val initContainerImage = sparkConf - .get(INIT_CONTAINER_IMAGE) - .getOrElse(throw new SparkException( - "Must specify the init-container image when there are remote dependencies")) - - def getAllConfigurationSteps: Seq[InitContainerConfigurationStep] = { - val initContainerBootstrap = new InitContainerBootstrap( - initContainerImage, - imagePullPolicy, - jarsDownloadPath, - filesDownloadPath, - configMapName, - configMapKey, - SPARK_POD_DRIVER_ROLE, - sparkConf) - val baseStep = new BasicInitContainerConfigurationStep( - sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath, - initContainerBootstrap) - - val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_SECRETS_PREFIX) - // Mount user-specified driver secrets also into the driver's init-container. The - // init-container may need credentials in the secrets to be able to download remote - // dependencies. The driver's main container and its init-container share the secrets - // because the init-container is sort of an implementation details and this sharing - // avoids introducing a dedicated configuration property just for the init-container. - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new InitContainerMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - - Seq(baseStep) ++ mountSecretsStep - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala deleted file mode 100644 index b52c343f0c0e..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import io.fabric8.kubernetes.api.model.{Container, HasMetadata, Pod} - -/** - * Represents a specification of the init-container for the driver pod. - * - * @param properties properties that should be set on the init-container - * @param driverSparkConf Spark configuration properties that will be carried back to the driver - * @param initContainer the init-container object - * @param driverContainer the driver container object - * @param driverPod the driver pod object - * @param dependentResources resources the init-container depends on to work - */ -private[spark] case class InitContainerSpec( - properties: Map[String, String], - driverSparkConf: Map[String, String], - initContainer: Container, - driverContainer: Container, - driverPod: Pod, - dependentResources: Seq[HasMetadata]) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala deleted file mode 100644 index 4a4b628aedbb..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.rest.k8s - -import java.io.File -import java.util.concurrent.TimeUnit - -import scala.concurrent.{ExecutionContext, Future} - -import org.apache.spark.{SecurityManager => SparkSecurityManager, SparkConf} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.{ThreadUtils, Utils} - -/** - * Process that fetches files from a resource staging server and/or arbitrary remote locations. - * - * The init-container can handle fetching files from any of those sources, but not all of the - * sources need to be specified. 
This allows for composing multiple instances of this container - * with different configurations for different download sources, or using the same container to - * download everything at once. - */ -private[spark] class SparkPodInitContainer( - sparkConf: SparkConf, - fileFetcher: FileFetcher) extends Logging { - - private val maxThreadPoolSize = sparkConf.get(INIT_CONTAINER_MAX_THREAD_POOL_SIZE) - private implicit val downloadExecutor = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("download-executor", maxThreadPoolSize)) - - private val jarsDownloadDir = new File(sparkConf.get(JARS_DOWNLOAD_LOCATION)) - private val filesDownloadDir = new File(sparkConf.get(FILES_DOWNLOAD_LOCATION)) - - private val remoteJars = sparkConf.get(INIT_CONTAINER_REMOTE_JARS) - private val remoteFiles = sparkConf.get(INIT_CONTAINER_REMOTE_FILES) - - private val downloadTimeoutMinutes = sparkConf.get(INIT_CONTAINER_MOUNT_TIMEOUT) - - def run(): Unit = { - logInfo(s"Downloading remote jars: $remoteJars") - downloadFiles( - remoteJars, - jarsDownloadDir, - s"Remote jars download directory specified at $jarsDownloadDir does not exist " + - "or is not a directory.") - - logInfo(s"Downloading remote files: $remoteFiles") - downloadFiles( - remoteFiles, - filesDownloadDir, - s"Remote files download directory specified at $filesDownloadDir does not exist " + - "or is not a directory.") - - downloadExecutor.shutdown() - downloadExecutor.awaitTermination(downloadTimeoutMinutes, TimeUnit.MINUTES) - } - - private def downloadFiles( - filesCommaSeparated: Option[String], - downloadDir: File, - errMessage: String): Unit = { - filesCommaSeparated.foreach { files => - require(downloadDir.isDirectory, errMessage) - Utils.stringToSeq(files).foreach { file => - Future[Unit] { - fileFetcher.fetchFile(file, downloadDir) - } - } - } - } -} - -private class FileFetcher(sparkConf: SparkConf, securityManager: SparkSecurityManager) { - - def fetchFile(uri: String, targetDir: File): Unit = { - Utils.fetchFile( - url = uri, - targetDir = targetDir, - conf = sparkConf, - securityMgr = securityManager, - hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf), - timestamp = System.currentTimeMillis(), - useCache = false) - } -} - -object SparkPodInitContainer extends Logging { - - def main(args: Array[String]): Unit = { - logInfo("Starting init-container to download Spark application dependencies.") - val sparkConf = new SparkConf(true) - if (args.nonEmpty) { - Utils.loadDefaultSparkProperties(sparkConf, args(0)) - } - - val securityManager = new SparkSecurityManager(sparkConf) - val fileFetcher = new FileFetcher(sparkConf, securityManager) - new SparkPodInitContainer(sparkConf, fileFetcher).run() - logInfo("Finished downloading application dependencies.") - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala deleted file mode 100644 index ba5d891f4c77..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
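The deleted init-container entry point above takes an optional first argument pointing at a properties file loaded via loadDefaultSparkProperties. Inside the init-container image it would be launched roughly like this (the path is hypothetical, typically a file mounted from the init-container's config map):

  // Hypothetical invocation of the deleted entry point; the properties file path is an example.
  SparkPodInitContainer.main(Array("/opt/spark/conf/spark-init.properties"))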
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} -import org.apache.spark.util.Utils - -/** - * A factory class for bootstrapping and creating executor pods with the given bootstrapping - * components. - * - * @param sparkConf Spark configuration - * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto - * user-specified paths into the executor container - * @param initContainerBootstrap an optional component for bootstrapping the executor init-container - * if one is needed, i.e., when there are remote dependencies to - * localize - * @param initContainerMountSecretsBootstrap an optional component for mounting user-specified - * secrets onto user-specified paths into the executor - * init-container - */ -private[spark] class ExecutorPodFactory( - sparkConf: SparkConf, - mountSecretsBootstrap: Option[MountSecretsBootstrap], - initContainerBootstrap: Option[InitContainerBootstrap], - initContainerMountSecretsBootstrap: Option[MountSecretsBootstrap]) { - - private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) - - private val executorLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_LABEL_PREFIX) - require( - !executorLabels.contains(SPARK_APP_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") - require( - !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + - " Spark.") - require( - !executorLabels.contains(SPARK_ROLE_LABEL), - s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") - - private val executorAnnotations = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - private val nodeSelector = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_NODE_SELECTOR_PREFIX) - - private val executorContainerImage = sparkConf - .get(EXECUTOR_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the executor container image")) - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val blockManagerPort = sparkConf - .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) - - private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) - - private val executorMemoryMiB = sparkConf.get(EXECUTOR_MEMORY) - private val 
executorMemoryString = sparkConf.get( - EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) - - private val memoryOverheadMiB = sparkConf - .get(EXECUTOR_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, - MEMORY_OVERHEAD_MIN_MIB)) - private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - - private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) - private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - - /** - * Configure and construct an executor pod with the given parameters. - */ - def createExecutorPod( - executorId: String, - applicationId: String, - driverUrl: String, - executorEnvs: Seq[(String, String)], - driverPod: Pod, - nodeToLocalTaskCount: Map[String, Int]): Pod = { - val name = s"$executorPodNamePrefix-exec-$executorId" - - // hostname must be no longer than 63 characters, so take the last 63 characters of the pod - // name as the hostname. This preserves uniqueness since the end of name contains - // executorId - val hostname = name.substring(Math.max(0, name.length - 63)) - val resolvedExecutorLabels = Map( - SPARK_EXECUTOR_ID_LABEL -> executorId, - SPARK_APP_ID_LABEL -> applicationId, - SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ - executorLabels - val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryMiB}Mi") - .build() - val executorMemoryLimitQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryWithOverhead}Mi") - .build() - val executorCpuQuantity = new QuantityBuilder(false) - .withAmount(executorCores.toString) - .build() - val executorExtraClasspathEnv = executorExtraClasspath.map { cp => - new EnvVarBuilder() - .withName(ENV_EXECUTOR_EXTRA_CLASSPATH) - .withValue(cp) - .build() - } - val executorExtraJavaOptionsEnv = sparkConf - .get(EXECUTOR_JAVA_OPTIONS) - .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) - delimitedOpts.zipWithIndex.map { - case (opt, index) => - new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() - } - }.getOrElse(Seq.empty[EnvVar]) - val executorEnv = (Seq( - (ENV_DRIVER_URL, driverUrl), - // Executor backend expects integral value for executor cores, so round it up to an int. 
- (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), - (ENV_EXECUTOR_MEMORY, executorMemoryString), - (ENV_APPLICATION_ID, applicationId), - (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) - .map(env => new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - ) ++ Seq( - new EnvVarBuilder() - .withName(ENV_EXECUTOR_POD_IP) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .build() - ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq - val requiredPorts = Seq( - (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) - .map { case (name, port) => - new ContainerPortBuilder() - .withName(name) - .withContainerPort(port) - .build() - } - - val executorContainer = new ContainerBuilder() - .withName("executor") - .withImage(executorContainerImage) - .withImagePullPolicy(imagePullPolicy) - .withNewResources() - .addToRequests("memory", executorMemoryQuantity) - .addToLimits("memory", executorMemoryLimitQuantity) - .addToRequests("cpu", executorCpuQuantity) - .endResources() - .addAllToEnv(executorEnv.asJava) - .withPorts(requiredPorts.asJava) - .build() - - val executorPod = new PodBuilder() - .withNewMetadata() - .withName(name) - .withLabels(resolvedExecutorLabels.asJava) - .withAnnotations(executorAnnotations.asJava) - .withOwnerReferences() - .addNewOwnerReference() - .withController(true) - .withApiVersion(driverPod.getApiVersion) - .withKind(driverPod.getKind) - .withName(driverPod.getMetadata.getName) - .withUid(driverPod.getMetadata.getUid) - .endOwnerReference() - .endMetadata() - .withNewSpec() - .withHostname(hostname) - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .endSpec() - .build() - - val containerWithLimitCores = executorLimitCores.map { limitCores => - val executorCpuLimitQuantity = new QuantityBuilder(false) - .withAmount(limitCores) - .build() - new ContainerBuilder(executorContainer) - .editResources() - .addToLimits("cpu", executorCpuLimitQuantity) - .endResources() - .build() - }.getOrElse(executorContainer) - - val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = - mountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(executorPod, containerWithLimitCores) - }.getOrElse((executorPod, containerWithLimitCores)) - - val (bootstrappedPod, bootstrappedContainer) = - initContainerBootstrap.map { bootstrap => - val podWithInitContainer = bootstrap.bootstrapInitContainer( - PodWithDetachedInitContainer( - maybeSecretsMountedPod, - new ContainerBuilder().build(), - maybeSecretsMountedContainer)) - - val (pod, mayBeSecretsMountedInitContainer) = - initContainerMountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(podWithInitContainer.pod, podWithInitContainer.initContainer) - }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) - - val bootstrappedPod = KubernetesUtils.appendInitContainer( - pod, mayBeSecretsMountedInitContainer) - - (bootstrappedPod, podWithInitContainer.mainContainer) - }.getOrElse((maybeSecretsMountedPod, maybeSecretsMountedContainer)) - - new PodBuilder(bootstrappedPod) - .editSpec() - .addToContainers(bootstrappedContainer) - .endSpec() - .build() - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index a942db6ae02d..0ea80dfbc0d9 100644 --- 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,7 +21,7 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -33,7 +33,9 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { - if (masterURL.startsWith("k8s") && sc.deployMode == "client") { + if (masterURL.startsWith("k8s") && + sc.deployMode == "client" && + !sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK).getOrElse(false)) { throw new SparkException("Client mode is currently not supported for Kubernetes.") } @@ -44,75 +46,16 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val sparkConf = sc.getConf - val initContainerConfigMap = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_NAME) - val initContainerConfigMapKey = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_KEY_CONF) - - if (initContainerConfigMap.isEmpty) { - logWarning("The executor's init-container config map is not specified. Executors will " + - "therefore not attempt to fetch remote or submitted dependencies.") - } - - if (initContainerConfigMapKey.isEmpty) { - logWarning("The executor's init-container config map key is not specified. Executors will " + - "therefore not attempt to fetch remote or submitted dependencies.") - } - - // Only set up the bootstrap if they've provided both the config map key and the config map - // name. The config map might not be provided if init-containers aren't being used to - // bootstrap dependencies. - val initContainerBootstrap = for { - configMap <- initContainerConfigMap - configMapKey <- initContainerConfigMapKey - } yield { - val initContainerImage = sparkConf - .get(INIT_CONTAINER_IMAGE) - .getOrElse(throw new SparkException( - "Must specify the init-container image when there are remote dependencies")) - new InitContainerBootstrap( - initContainerImage, - sparkConf.get(CONTAINER_IMAGE_PULL_POLICY), - sparkConf.get(JARS_DOWNLOAD_LOCATION), - sparkConf.get(FILES_DOWNLOAD_LOCATION), - configMap, - configMapKey, - SPARK_POD_EXECUTOR_ROLE, - sparkConf) - } - val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) - val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { - Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) - } else { - None - } - // Mount user-specified executor secrets also into the executor's init-container. The - // init-container may need credentials in the secrets to be able to download remote - // dependencies. 
The executor's main container and its init-container share the secrets - // because the init-container is sort of an implementation details and this sharing - // avoids introducing a dedicated configuration property just for the init-container. - val initContainerMountSecretsBootstrap = if (initContainerBootstrap.nonEmpty && - executorSecretNamesToMountPaths.nonEmpty) { - Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) - } else { - None - } - + sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, - Some(sparkConf.get(KUBERNETES_NAMESPACE)), + Some(sc.conf.get(KUBERNETES_NAMESPACE)), KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, - sparkConf, + sc.conf, Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactory( - sparkConf, - mountSecretBootstrap, - initContainerBootstrap, - initContainerMountSecretsBootstrap) - val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( @@ -120,7 +63,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - executorPodFactory, + new KubernetesExecutorBuilder, kubernetesClient, allocatorExecutor, requestExecutorsService) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 9de4b16c30d3..d86664c81071 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -32,6 +32,7 @@ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.SparkException import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorPodFactory: ExecutorPodFactory, + executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, allocatorExecutor: ScheduledExecutorService, requestExecutorsService: ExecutorService) @@ -115,14 +116,19 @@ private[spark] class KubernetesClusterSchedulerBackend( for (_ <- 0 until math.min( currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorPod = executorPodFactory.createExecutorPod( + val executorConf = KubernetesConf.createExecutorConf( + conf, executorId, applicationId(), - driverUrl, - conf.getExecutorEnv, - driverPod, - currentNodeToLocalTaskCount) - executorsToAllocate(executorId) = executorPod + driverPod) + val executorPod = 
executorBuilder.buildFromFeatures(executorConf)
+ val podWithAttachedContainer = new PodBuilder(executorPod.pod)
+ .editOrNewSpec()
+ .addToContainers(executorPod.container)
+ .endSpec()
+ .build()
+
+ executorsToAllocate(executorId) = podWithAttachedContainer
 logInfo(
 s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}")
 }
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
new file mode 100644
index 000000000000..22568fe7ea3b
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep}
+
+private[spark] class KubernetesExecutorBuilder(
+ provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep =
+ new BasicExecutorFeatureStep(_),
+ provideSecretsStep:
+ (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep =
+ new MountSecretsFeatureStep(_)) {
+
+ def buildFromFeatures(
+ kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = {
+ val baseFeatures = Seq(provideBasicStep(kubernetesConf))
+ val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) {
+ baseFeatures ++ Seq(provideSecretsStep(kubernetesConf))
+ } else baseFeatures
+ var executorPod = SparkPod.initialPod()
+ for (feature <- allFeatures) {
+ executorPod = feature.configurePod(executorPod)
+ }
+ executorPod
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala
new file mode 100644
index 000000000000..f10202f7a354
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReferenceBuilder, PodBuilder} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource + +class KubernetesConfSuite extends SparkFunSuite { + + private val APP_NAME = "test-app" + private val RESOURCE_NAME_PREFIX = "prefix" + private val APP_ID = "test-id" + private val MAIN_CLASS = "test-class" + private val APP_ARGS = Array("arg1", "arg2") + private val CUSTOM_LABELS = Map( + "customLabel1Key" -> "customLabel1Value", + "customLabel2Key" -> "customLabel2Value") + private val CUSTOM_ANNOTATIONS = Map( + "customAnnotation1Key" -> "customAnnotation1Value", + "customAnnotation2Key" -> "customAnnotation2Value") + private val SECRET_NAMES_TO_MOUNT_PATHS = Map( + "secret1" -> "/mnt/secrets/secret1", + "secret2" -> "/mnt/secrets/secret2") + private val CUSTOM_ENVS = Map( + "customEnvKey1" -> "customEnvValue1", + "customEnvKey2" -> "customEnvValue2") + private val DRIVER_POD = new PodBuilder().build() + private val EXECUTOR_ID = "executor-id" + + test("Basic driver translated fields.") { + val sparkConf = new SparkConf(false) + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.appId === APP_ID) + assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) + assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) + assert(conf.roleSpecificConf.appName === APP_NAME) + assert(conf.roleSpecificConf.mainAppResource.isEmpty) + assert(conf.roleSpecificConf.mainClass === MAIN_CLASS) + assert(conf.roleSpecificConf.appArgs === APP_ARGS) + } + + test("Creating driver conf with and without the main app jar influences spark.jars") { + val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + val mainAppJar = Some(JavaMainAppResource("local:///opt/spark/main.jar")) + val kubernetesConfWithMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppJar, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars") + .split(",") + === Array("local:///opt/spark/jar1.jar", "local:///opt/spark/main.jar")) + val kubernetesConfWithoutMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + } + + test("Resolve driver labels, annotations, secret mount paths, and envs.") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + 
sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) + } + CUSTOM_ENVS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) + } + + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.roleLabels === Map( + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++ + CUSTOM_LABELS) + assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleEnvs === CUSTOM_ENVS) + } + + test("Basic executor translated fields.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleSpecificConf.executorId === EXECUTOR_ID) + assert(conf.roleSpecificConf.driverPod === DRIVER_POD) + } + + test("Image pull secrets.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false) + .set(IMAGE_PULL_SECRETS, "my-secret-1,my-secret-2 "), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.imagePullSecrets() === + Seq( + new LocalObjectReferenceBuilder().withName("my-secret-1").build(), + new LocalObjectReferenceBuilder().withName("my-secret-2").build())) + } + + test("Set executor labels, annotations, and secrets") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) + } + + val conf = KubernetesConf.createExecutorConf( + sparkConf, + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleLabels === Map( + SPARK_EXECUTOR_ID_LABEL -> EXECUTOR_ID, + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) + assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala similarity index 71% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala index 8388c16ded26..16780584a674 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.deploy.k8s.submit +package org.apache.spark.deploy.k8s import scala.collection.JavaConverters._ @@ -22,15 +22,15 @@ import io.fabric8.kubernetes.api.model.{Container, Pod} private[spark] object SecretVolumeUtils { - def podHasVolume(driverPod: Pod, volumeName: String): Boolean = { - driverPod.getSpec.getVolumes.asScala.exists(volume => volume.getName == volumeName) + def podHasVolume(pod: Pod, volumeName: String): Boolean = { + pod.getSpec.getVolumes.asScala.exists { volume => + volume.getName == volumeName + } } - def containerHasVolume( - driverContainer: Container, - volumeName: String, - mountPath: String): Boolean = { - driverContainer.getVolumeMounts.asScala.exists(volumeMount => - volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath) + def containerHasVolume(container: Container, volumeName: String, mountPath: String): Boolean = { + container.getVolumeMounts.asScala.exists { volumeMount => + volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath + } } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala new file mode 100644 index 000000000000..eee85b8baa73 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.LocalObjectReferenceBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class BasicDriverFeatureStepSuite extends SparkFunSuite { + + private val APP_ID = "spark-app-id" + private val RESOURCE_NAME_PREFIX = "spark" + private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") + private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" + private val APP_NAME = "spark-test" + private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") + private val CUSTOM_ANNOTATION_KEY = "customAnnotation" + private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" + private val DRIVER_ANNOTATIONS = Map(CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE) + private val DRIVER_CUSTOM_ENV1 = "customDriverEnv1" + private val DRIVER_CUSTOM_ENV2 = "customDriverEnv2" + private val DRIVER_ENVS = Map( + DRIVER_CUSTOM_ENV1 -> DRIVER_CUSTOM_ENV1, + DRIVER_CUSTOM_ENV2 -> DRIVER_CUSTOM_ENV2) + private val TEST_IMAGE_PULL_SECRETS = Seq("my-secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + + test("Check the pod respects all configurations from the user.") { + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .set("spark.driver.cores", "2") + .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") + .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + DRIVER_ENVS) + + val featureStep = new BasicDriverFeatureStep(kubernetesConf) + val basePod = SparkPod.initialPod() + val configuredPod = featureStep.configurePod(basePod) + + assert(configuredPod.container.getName === DRIVER_CONTAINER_NAME) + assert(configuredPod.container.getImage === "spark-driver:latest") + assert(configuredPod.container.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) + + assert(configuredPod.container.getEnv.size === 3) + val envs = configuredPod.container + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(DRIVER_CUSTOM_ENV1) === DRIVER_ENVS(DRIVER_CUSTOM_ENV1)) + assert(envs(DRIVER_CUSTOM_ENV2) === DRIVER_ENVS(DRIVER_CUSTOM_ENV2)) + + assert(configuredPod.pod.getSpec().getImagePullSecrets.asScala === + TEST_IMAGE_PULL_SECRET_OBJECTS) + + assert(configuredPod.container.getEnv.asScala.exists(envVar => + envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && + envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && + envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) + + val resourceRequirements = configuredPod.container.getResources + val requests = resourceRequirements.getRequests.asScala + assert(requests("cpu").getAmount === "2") + assert(requests("memory").getAmount === "456Mi") + val limits = resourceRequirements.getLimits.asScala + 
assert(limits("memory").getAmount === "456Mi") + assert(limits("cpu").getAmount === "4") + + val driverPodMetadata = configuredPod.pod.getMetadata + assert(driverPodMetadata.getName === "spark-driver-pod") + assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) + assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) + assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") + + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true") + assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) + } + + test("Additional system properties resolve jars and set cluster-mode confs.") { + val allJars = Seq("local:///opt/spark/jar1.jar", "hdfs:///opt/spark/jar2.jar") + val allFiles = Seq("https://localhost:9000/file1.txt", "local:///opt/spark/file2.txt") + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .setJars(allJars) + .set("spark.files", allFiles.mkString(",")) + .set(CONTAINER_IMAGE, "spark-driver:latest") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty) + val step = new BasicDriverFeatureStep(kubernetesConf) + val additionalProperties = step.getAdditionalPodSystemProperties() + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true", + "spark.jars" -> "/opt/spark/jar1.jar,hdfs:///opt/spark/jar2.jar", + "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt") + assert(additionalProperties === expectedSparkConf) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala new file mode 100644 index 000000000000..a764f7630b5c --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ +import org.mockito.MockitoAnnotations +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend + +class BasicExecutorFeatureStepSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + + private val APP_ID = "app-id" + private val DRIVER_HOSTNAME = "localhost" + private val DRIVER_PORT = 7098 + private val DRIVER_ADDRESS = RpcEndpointAddress( + DRIVER_HOSTNAME, + DRIVER_PORT.toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + private val DRIVER_POD_NAME = "driver-pod" + + private val DRIVER_POD_UID = "driver-uid" + private val RESOURCE_NAME_PREFIX = "base" + private val EXECUTOR_IMAGE = "executor-image" + private val LABELS = Map("label1key" -> "label1value") + private val ANNOTATIONS = Map("annotation1key" -> "annotation1value") + private val TEST_IMAGE_PULL_SECRETS = Seq("my-1secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + private val DRIVER_POD = new PodBuilder() + .withNewMetadata() + .withName(DRIVER_POD_NAME) + .withUid(DRIVER_POD_UID) + .endMetadata() + .withNewSpec() + .withNodeName("some-node") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private var baseConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + baseConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, RESOURCE_NAME_PREFIX) + .set(CONTAINER_IMAGE, EXECUTOR_IMAGE) + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) + .set("spark.driver.host", DRIVER_HOSTNAME) + .set("spark.driver.port", DRIVER_PORT.toString) + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + } + + test("basic executor pod has reasonable defaults") { + val step = new BasicExecutorFeatureStep( + KubernetesConf( + baseConf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + val executor = step.configurePod(SparkPod.initialPod()) + + // The executor pod name and default labels. + assert(executor.pod.getMetadata.getName === s"$RESOURCE_NAME_PREFIX-exec-1") + assert(executor.pod.getMetadata.getLabels.asScala === LABELS) + assert(executor.pod.getSpec.getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS) + + // There is exactly 1 container with no volume mounts and default memory limits. + // Default memory limit is 1024M + 384M (minimum overhead constant). + assert(executor.container.getImage === EXECUTOR_IMAGE) + assert(executor.container.getVolumeMounts.isEmpty) + assert(executor.container.getResources.getLimits.size() === 1) + assert(executor.container.getResources + .getLimits.get("memory").getAmount === "1408Mi") + + // The pod has no node selector, volumes. 
+ assert(executor.pod.getSpec.getNodeSelector.isEmpty) + assert(executor.pod.getSpec.getVolumes.isEmpty) + + checkEnv(executor, Map()) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + test("executor pod hostnames get truncated to 63 characters") { + val conf = baseConf.clone() + val longPodNamePrefix = "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple" + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + longPodNamePrefix, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) + } + + test("classpath and extra java options get translated into environment variables") { + val conf = baseConf.clone() + conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map("qux" -> "quux"))) + val executor = step.configurePod(SparkPod.initialPod()) + + checkEnv(executor, + Map("SPARK_JAVA_OPT_0" -> "foo=bar", + ENV_CLASSPATH -> "bar=baz", + "qux" -> "quux")) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + // There is always exactly one controller reference, and it points to the driver pod. + private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { + assert(executor.getMetadata.getOwnerReferences.size() === 1) + assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) + assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) + } + + // Check that the expected environment variables are present. + private def checkEnv(executorPod: SparkPod, additionalEnvVars: Map[String, String]): Unit = { + val defaultEnvs = Map( + ENV_EXECUTOR_ID -> "1", + ENV_DRIVER_URL -> DRIVER_ADDRESS.toString, + ENV_EXECUTOR_CORES -> "1", + ENV_EXECUTOR_MEMORY -> "1g", + ENV_APPLICATION_ID -> APP_ID, + ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + + assert(executorPod.container.getEnv.size() === defaultEnvs.size) + val mapEnvs = executorPod.container.getEnv.asScala.map { + x => (x.getName, x.getValue) + }.toMap + assert(defaultEnvs === mapEnvs) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala similarity index 67% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 64553d25883b..9f817d3bfc79 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -14,34 +14,35 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features import java.io.File -import scala.collection.JavaConverters._ - import com.google.common.base.Charsets import com.google.common.io.{BaseEncoding, Files} import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret} +import org.mockito.{Mock, MockitoAnnotations} import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.util.Utils -class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndAfter { +class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val KUBERNETES_RESOURCE_NAME_PREFIX = "spark" + private val APP_ID = "k8s-app" private var credentialsTempDirectory: File = _ - private val BASE_DRIVER_SPEC = new KubernetesDriverSpec( - driverPod = new PodBuilder().build(), - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) + private val BASE_DRIVER_POD = SparkPod.initialPod() + + @Mock + private var driverSpecificConf: KubernetesDriverSpecificConf = _ before { + MockitoAnnotations.initMocks(this) credentialsTempDirectory = Utils.createTempDir() } @@ -50,13 +51,19 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA } test("Don't set any credentials") { - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - new SparkConf(false), KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.isEmpty) + val kubernetesConf = KubernetesConf( + new SparkConf(false), + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) } test("Only set credentials that are manually mounted.") { @@ -73,14 +80,23 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", "/mnt/secrets/my-ca.pem") + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) 
- assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === submissionSparkConf.getAll.toMap) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() + resolvedProperties.foreach { case (propKey, propValue) => + assert(submissionSparkConf.get(propKey) === propValue) + } } test("Mount credentials from the submission client as a secret.") { @@ -100,10 +116,17 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", caCertFile.getAbsolutePath) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver( - BASE_DRIVER_SPEC.copy(driverSparkConf = submissionSparkConf)) + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() val expectedSparkConf = Map( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX" -> "", s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> @@ -113,16 +136,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> DRIVER_CREDENTIALS_CLIENT_CERT_PATH, s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - DRIVER_CREDENTIALS_CA_CERT_PATH, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> - clientKeyFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> - clientCertFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - caCertFile.getAbsolutePath) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) - assert(preparedDriverSpec.otherKubernetesResources.size === 1) - val credentialsSecret = preparedDriverSpec.otherKubernetesResources.head.asInstanceOf[Secret] + DRIVER_CREDENTIALS_CA_CERT_PATH) + assert(resolvedProperties === expectedSparkConf) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().size === 1) + val credentialsSecret = kubernetesCredentialsStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Secret] assert(credentialsSecret.getMetadata.getName === s"$KUBERNETES_RESOURCE_NAME_PREFIX-kubernetes-credentials") val decodedSecretData = credentialsSecret.getData.asScala.map { data => @@ -134,12 +154,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME -> "key", DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME -> "cert") assert(decodedSecretData === expectedSecretData) - val driverPodVolumes = preparedDriverSpec.driverPod.getSpec.getVolumes.asScala + val driverPod = kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) + val driverPodVolumes = driverPod.pod.getSpec.getVolumes.asScala 
assert(driverPodVolumes.size === 1) assert(driverPodVolumes.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverPodVolumes.head.getSecret != null) assert(driverPodVolumes.head.getSecret.getSecretName === credentialsSecret.getMetadata.getName) - val driverContainerVolumeMount = preparedDriverSpec.driverContainer.getVolumeMounts.asScala + val driverContainerVolumeMount = driverPod.container.getVolumeMounts.asScala assert(driverContainerVolumeMount.size === 1) assert(driverContainerVolumeMount.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverContainerVolumeMount.head.getMountPath === DRIVER_CREDENTIALS_SECRETS_BASE_DIR) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala new file mode 100644 index 000000000000..c299d56865ec --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.Service +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.util.Clock + +class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { + + private val SHORT_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length) + + private val LONG_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length + 1) + private val DRIVER_LABELS = Map( + "label1key" -> "label1value", + "label2key" -> "label2value") + + @Mock + private var clock: Clock = _ + + private var sparkConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + } + + test("Headless service has a port for the driver RPC and the block manager.") { + sparkConf = sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) + assert(configurationStep.getAdditionalKubernetesResources().size === 1) + assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service]) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + 9000, + 8080, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + driverService) + } + + test("Hostname and ports are set according to the service name.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + .set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Ports should resolve to defaults in SparkConf and in the service.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val resolvedService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + DEFAULT_DRIVER_PORT, + DEFAULT_BLOCKMANAGER_PORT, + 
s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + resolvedService) + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + assert(additionalProps("spark.driver.port") === DEFAULT_DRIVER_PORT.toString) + assert(additionalProps(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key) + === DEFAULT_BLOCKMANAGER_PORT.toString) + } + + test("Long prefixes should switch to using a generated name.") { + when(clock.getTimeMillis()).thenReturn(10000) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + val expectedServiceName = s"spark-10000${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}" + assert(driverService.getMetadata.getName === expectedServiceName) + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Disallow bind address and driver host to be set explicitly.") { + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver bind address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_BIND_ADDRESS_KEY} is" + + " not supported in Kubernetes mode, as the driver's bind address is managed" + + " and set to the driver pod's IP address.") + } + sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) + sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver host address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_HOST_KEY} is" + + " not supported in Kubernetes mode, as the driver's hostname will be managed via" + + " a Kubernetes service.") + } + } + + private def verifyService( + driverPort: Int, + blockManagerPort: Int, + expectedServiceName: String, + service: Service): Unit = { + assert(service.getMetadata.getName === expectedServiceName) + assert(service.getSpec.getClusterIP === "None") + assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) + assert(service.getSpec.getPorts.size() === 2) + val driverServicePorts = service.getSpec.getPorts.asScala + assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) + assert(driverServicePorts.head.getPort.intValue() === driverPort) + assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) + assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) + assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) + assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) + 
}
+
+ private def verifySparkConfHostNames(
+ driverSparkConf: Map[String, String], expectedHostName: String): Unit = {
+ assert(driverSparkConf(
+ org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key) === expectedHostName)
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala
new file mode 100644
index 000000000000..27bff74ce38a
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder}
+import org.mockito.Matchers
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark.deploy.k8s.SparkPod
+
+object KubernetesFeaturesTestUtils {
+
+ def getMockConfigStepForStepType[T <: KubernetesFeatureConfigStep](
+ stepType: String, stepClass: Class[T]): T = {
+ val mockStep = mock(stepClass)
+ when(mockStep.getAdditionalKubernetesResources()).thenReturn(
+ getSecretsForStepType(stepType))
+
+ when(mockStep.getAdditionalPodSystemProperties())
+ .thenReturn(Map(stepType -> stepType))
+ when(mockStep.configurePod(Matchers.any(classOf[SparkPod])))
+ .thenAnswer(new Answer[SparkPod]() {
+ override def answer(invocation: InvocationOnMock): SparkPod = {
+ val originalPod = invocation.getArgumentAt(0, classOf[SparkPod])
+ val configuredPod = new PodBuilder(originalPod.pod)
+ .editOrNewMetadata()
+ .addToLabels(stepType, stepType)
+ .endMetadata()
+ .build()
+ SparkPod(configuredPod, originalPod.container)
+ }
+ })
+ mockStep
+ }
+
+ def getSecretsForStepType[T <: KubernetesFeatureConfigStep](stepType: String)
+ : Seq[HasMetadata] = {
+ Seq(new SecretBuilder()
+ .withNewMetadata()
+ .withName(stepType)
+ .endMetadata()
+ .build())
+ }
+
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala
similarity index 64%
rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala
rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala
index 9ec0cb55de5a..9d02f56cc206 100644 ---
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -14,29 +14,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SecretVolumeUtils, SparkPod} -class DriverMountSecretsStepSuite extends SparkFunSuite { +class MountSecretsFeatureStepSuite extends SparkFunSuite { private val SECRET_FOO = "foo" private val SECRET_BAR = "bar" private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("mounts all given secrets") { - val baseDriverSpec = KubernetesDriverSpec.initialSpec(new SparkConf(false)) + val baseDriverPod = SparkPod.initialPod() val secretNamesToMountPaths = Map( SECRET_FOO -> SECRET_MOUNT_PATH, SECRET_BAR -> SECRET_MOUNT_PATH) + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + secretNamesToMountPaths, + Map.empty) - val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) - val mountSecretsStep = new DriverMountSecretsStep(mountSecretsBootstrap) - val configuredDriverSpec = mountSecretsStep.configureDriver(baseDriverSpec) - val driverPodWithSecretsMounted = configuredDriverSpec.driverPod - val driverContainerWithSecretsMounted = configuredDriverSpec.driverContainer + val step = new MountSecretsFeatureStep(kubernetesConf) + val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod + val driverContainerWithSecretsMounted = step.configurePod(baseDriverPod).container Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => assert(SecretVolumeUtils.podHasVolume(driverPodWithSecretsMounted, volumeName)) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index bf4ec0489320..c1b203e03a35 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -16,28 +16,92 @@ */ package org.apache.spark.deploy.k8s.submit -import scala.collection.JavaConverters._ - -import com.google.common.collect.Iterables import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.{KubernetesClient, Watch} import io.fabric8.kubernetes.client.dsl.{MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} import org.mockito.Mockito.{doReturn, verify, when} -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} +import 
org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val DRIVER_POD_UID = "pod-id" private val DRIVER_POD_API_VERSION = "v1" private val DRIVER_POD_KIND = "pod" + private val KUBERNETES_RESOURCE_PREFIX = "resource-example" + private val POD_NAME = "driver" + private val CONTAINER_NAME = "container" + private val APP_ID = "app-id" + private val APP_NAME = "app" + private val MAIN_CLASS = "main" + private val APP_ARGS = Seq("arg1", "arg2") + private val RESOLVED_JAVA_OPTIONS = Map( + "conf1key" -> "conf1value", + "conf2key" -> "conf2value") + private val BUILT_DRIVER_POD = + new PodBuilder() + .withNewMetadata() + .withName(POD_NAME) + .endMetadata() + .withNewSpec() + .withHostname("localhost") + .endSpec() + .build() + private val BUILT_DRIVER_CONTAINER = new ContainerBuilder().withName(CONTAINER_NAME).build() + private val ADDITIONAL_RESOURCES = Seq( + new SecretBuilder().withNewMetadata().withName("secret").endMetadata().build()) + + private val BUILT_KUBERNETES_SPEC = KubernetesDriverSpec( + SparkPod(BUILT_DRIVER_POD, BUILT_DRIVER_CONTAINER), + ADDITIONAL_RESOURCES, + RESOLVED_JAVA_OPTIONS) + + private val FULL_EXPECTED_CONTAINER = new ContainerBuilder(BUILT_DRIVER_CONTAINER) + .addNewEnv() + .withName(ENV_SPARK_CONF_DIR) + .withValue(SPARK_CONF_DIR_INTERNAL) + .endEnv() + .addNewVolumeMount() + .withName(SPARK_CONF_VOLUME) + .withMountPath(SPARK_CONF_DIR_INTERNAL) + .endVolumeMount() + .build() + private val FULL_EXPECTED_POD = new PodBuilder(BUILT_DRIVER_POD) + .editSpec() + .addToContainers(FULL_EXPECTED_CONTAINER) + .addNewVolume() + .withName(SPARK_CONF_VOLUME) + .withNewConfigMap().withName(s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map").endConfigMap() + .endVolume() + .endSpec() + .build() + + private val POD_WITH_OWNER_REFERENCE = new PodBuilder(FULL_EXPECTED_POD) + .editMetadata() + .withUid(DRIVER_POD_UID) + .endMetadata() + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .build() + + private val ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES = ADDITIONAL_RESOURCES.map { secret => + new SecretBuilder(secret) + .editMetadata() + .addNewOwnerReference() + .withName(POD_NAME) + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .withController(true) + .withUid(DRIVER_POD_UID) + .endOwnerReference() + .endMetadata() + .build() + } private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ HasMetadata, Boolean] @@ -55,180 +119,92 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @Mock private var loggingPodStatusWatcher: LoggingPodStatusWatcher = _ + @Mock + private var driverBuilder: KubernetesDriverBuilder = _ + @Mock private var resourceList: ResourceList = _ - private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep) + private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ + + private var sparkConf: SparkConf = _ private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ before { MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + kubernetesConf = KubernetesConf[KubernetesDriverSpecificConf]( + sparkConf, + KubernetesDriverSpecificConf(None, MAIN_CLASS, APP_NAME, APP_ARGS), + KUBERNETES_RESOURCE_PREFIX, + 
APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.withName(POD_NAME)).thenReturn(namedPods) createdPodArgumentCaptor = ArgumentCaptor.forClass(classOf[Pod]) createdResourcesArgumentCaptor = ArgumentCaptor.forClass(classOf[HasMetadata]) - when(podOperations.create(createdPodArgumentCaptor.capture())).thenAnswer(new Answer[Pod] { - override def answer(invocation: InvocationOnMock): Pod = { - new PodBuilder(invocation.getArgumentAt(0, classOf[Pod])) - .editMetadata() - .withUid(DRIVER_POD_UID) - .endMetadata() - .withApiVersion(DRIVER_POD_API_VERSION) - .withKind(DRIVER_POD_KIND) - .build() - } - }) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.create(FULL_EXPECTED_POD)).thenReturn(POD_WITH_OWNER_REFERENCE) when(namedPods.watch(loggingPodStatusWatcher)).thenReturn(mock[Watch]) doReturn(resourceList) .when(kubernetesClient) .resourceList(createdResourcesArgumentCaptor.capture()) } - test("The client should configure the pod with the submission steps.") { + test("The client should configure the pod using the builder.") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue - assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName) - assert(createdPod.getMetadata.getLabels.asScala === - Map(FirstTestConfigurationStep.labelKey -> FirstTestConfigurationStep.labelValue)) - assert(createdPod.getMetadata.getAnnotations.asScala === - Map(SecondTestConfigurationStep.annotationKey -> - SecondTestConfigurationStep.annotationValue)) - assert(createdPod.getSpec.getContainers.size() === 1) - assert(createdPod.getSpec.getContainers.get(0).getName === - SecondTestConfigurationStep.containerName) + verify(podOperations).create(FULL_EXPECTED_POD) } - test("The client should create the secondary Kubernetes resources.") { + test("The client should create Kubernetes resources") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues - assert(otherCreatedResources.size === 1) - val createdResource = Iterables.getOnlyElement(otherCreatedResources).asInstanceOf[Secret] - assert(createdResource.getMetadata.getName === FirstTestConfigurationStep.secretName) - assert(createdResource.getData.asScala === - Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData)) - val ownerReference = Iterables.getOnlyElement(createdResource.getMetadata.getOwnerReferences) - assert(ownerReference.getName === createdPod.getMetadata.getName) - assert(ownerReference.getKind === DRIVER_POD_KIND) - assert(ownerReference.getUid === DRIVER_POD_UID) - assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION) - } - - test("The client should attach the driver container with the appropriate JVM options.") 
{ - val sparkConf = new SparkConf(false) - .set("spark.logConf", "true") - .set( - org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, - "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails") - val submissionClient = new Client( - submissionSteps, - sparkConf, - kubernetesClient, - false, - "spark", - loggingPodStatusWatcher) - submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue - val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers) - assert(driverContainer.getName === SecondTestConfigurationStep.containerName) - val driverJvmOptsEnvs = driverContainer.getEnv.asScala.filter { env => - env.getName.startsWith(ENV_JAVA_OPT_PREFIX) - }.sortBy(_.getName) - assert(driverJvmOptsEnvs.size === 4) - - val expectedJvmOptsValues = Seq( - "-Dspark.logConf=true", - s"-D${SecondTestConfigurationStep.sparkConfKey}=" + - s"${SecondTestConfigurationStep.sparkConfValue}", - "-XX:+HeapDumpOnOutOfMemoryError", - "-XX:+PrintGCDetails") - driverJvmOptsEnvs.zip(expectedJvmOptsValues).zipWithIndex.foreach { - case ((resolvedEnv, expectedJvmOpt), index) => - assert(resolvedEnv.getName === s"$ENV_JAVA_OPT_PREFIX$index") - assert(resolvedEnv.getValue === expectedJvmOpt) - } + assert(otherCreatedResources.size === 2) + val secrets = otherCreatedResources.toArray.filter(_.isInstanceOf[Secret]).toSeq + assert(secrets === ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES) + val configMaps = otherCreatedResources.toArray + .filter(_.isInstanceOf[ConfigMap]).map(_.asInstanceOf[ConfigMap]) + assert(secrets.nonEmpty) + assert(configMaps.nonEmpty) + val configMap = configMaps.head + assert(configMap.getMetadata.getName === + s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map") + assert(configMap.getData.containsKey(SPARK_CONF_FILE_NAME)) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf1key=conf1value")) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf2key=conf2value")) } test("Waiting for app completion should stall on the watcher") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, true, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() verify(loggingPodStatusWatcher).awaitCompletion() } - -} - -private object FirstTestConfigurationStep extends DriverConfigurationStep { - - val podName = "test-pod" - val secretName = "test-secret" - val labelKey = "first-submit" - val labelValue = "true" - val secretKey = "secretKey" - val secretData = "secretData" - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new PodBuilder(driverSpec.driverPod) - .editMetadata() - .withName(podName) - .addToLabels(labelKey, labelValue) - .endMetadata() - .build() - val additionalResource = new SecretBuilder() - .withNewMetadata() - .withName(secretName) - .endMetadata() - .addToData(secretKey, secretData) - .build() - driverSpec.copy( - driverPod = modifiedPod, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(additionalResource)) - } -} - -private object SecondTestConfigurationStep extends DriverConfigurationStep { - - val annotationKey = "second-submit" - val annotationValue = "submitted" - val sparkConfKey = "spark.custom-conf" - val sparkConfValue = "custom-conf-value" - val containerName = "driverContainer" - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new 
PodBuilder(driverSpec.driverPod) - .editMetadata() - .addToAnnotations(annotationKey, annotationValue) - .endMetadata() - .build() - val resolvedSparkConf = driverSpec.driverSparkConf.clone().set(sparkConfKey, sparkConfValue) - val modifiedContainer = new ContainerBuilder(driverSpec.driverContainer) - .withName(containerName) - .build() - driverSpec.copy( - driverPod = modifiedPod, - driverSparkConf = resolvedSparkConf, - driverContainer = modifiedContainer) - } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala deleted file mode 100644 index f193b1f4d366..000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.submit.steps._ - -class DriverConfigOrchestratorSuite extends SparkFunSuite { - - private val DRIVER_IMAGE = "driver-image" - private val IC_IMAGE = "init-container-image" - private val APP_ID = "spark-app-id" - private val LAUNCH_TIME = 975256L - private val APP_NAME = "spark" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2") - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/driver" - - test("Base submission steps with a main app resource.") { - val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - LAUNCH_TIME, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep] - ) - } - - test("Base submission steps without a main app resource.") { - val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - LAUNCH_TIME, - Option.empty, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep] - ) - } - - test("Submission steps with an init-container.") { - val sparkConf 
= new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) - .set(INIT_CONTAINER_IMAGE, IC_IMAGE) - .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - LAUNCH_TIME, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep], - classOf[DriverInitContainerBootstrapStep]) - } - - test("Submission steps with driver secrets to mount") { - val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - LAUNCH_TIME, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep], - classOf[DriverMountSecretsStep]) - } - - private def validateStepTypes( - orchestrator: DriverConfigOrchestrator, - types: Class[_ <: DriverConfigurationStep]*): Unit = { - val steps = orchestrator.getAllConfigurationSteps - assert(steps.size === types.size) - assert(steps.map(_.getClass) === types) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala new file mode 100644 index 000000000000..161f9afe7bba --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesDriverBuilderSuite extends SparkFunSuite { + + private val BASIC_STEP_TYPE = "basic" + private val CREDENTIALS_STEP_TYPE = "credentials" + private val SERVICE_STEP_TYPE = "service" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) + + private val credentialsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + CREDENTIALS_STEP_TYPE, classOf[DriverKubernetesCredentialsFeatureStep]) + + private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) + + private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest: KubernetesDriverBuilder = + new KubernetesDriverBuilder( + _ => basicFeatureStep, + _ => credentialsStep, + _ => serviceStep, + _ => secretsStep) + + test("Apply fundamental steps all the time.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) + : Unit = { + assert(resolvedSpec.systemProperties.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedSpec.pod.pod.getMetadata.getLabels.get(stepType) === stepType) + assert(resolvedSpec.driverKubernetesResources.containsSlice( + KubernetesFeaturesTestUtils.getSecretsForStepType(stepType))) + assert(resolvedSpec.systemProperties(stepType) === stepType) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala deleted file mode 100644 index e864c6a16eeb..000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class BasicDriverConfigurationStepSuite extends SparkFunSuite { - - private val APP_ID = "spark-app-id" - private val RESOURCE_NAME_PREFIX = "spark" - private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") - private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" - private val APP_NAME = "spark-test" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "arg 3") - private val CUSTOM_ANNOTATION_KEY = "customAnnotation" - private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" - private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" - private val DRIVER_CUSTOM_ENV_KEY2 = "customDriverEnv2" - - test("Set all possible configurations from the user.") { - val sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") - .set(org.apache.spark.internal.config.DRIVER_CLASS_PATH, "/opt/spark/spark-examples.jar") - .set("spark.driver.cores", "2") - .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") - .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") - .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) - .set(DRIVER_CONTAINER_IMAGE, "spark-driver:latest") - .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") - - val submissionStep = new BasicDriverConfigurationStep( - APP_ID, - RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - CONTAINER_IMAGE_PULL_POLICY, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - val basePod = new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = basePod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = submissionStep.configureDriver(baseDriverSpec) - - assert(preparedDriverSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) - assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") - assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) - - assert(preparedDriverSpec.driverContainer.getEnv.size === 7) - val envs = preparedDriverSpec.driverContainer - .getEnv - .asScala - .map(env => (env.getName, env.getValue)) - .toMap - assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") - assert(envs(ENV_DRIVER_MEMORY) === 
"256M") - assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "\"arg1\" \"arg2\" \"arg 3\"") - assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") - assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") - - assert(preparedDriverSpec.driverContainer.getEnv.asScala.exists(envVar => - envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && - envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && - envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) - - val resourceRequirements = preparedDriverSpec.driverContainer.getResources - val requests = resourceRequirements.getRequests.asScala - assert(requests("cpu").getAmount === "2") - assert(requests("memory").getAmount === "256Mi") - val limits = resourceRequirements.getLimits.asScala - assert(limits("memory").getAmount === "456Mi") - assert(limits("cpu").getAmount === "4") - - val driverPodMetadata = preparedDriverSpec.driverPod.getMetadata - assert(driverPodMetadata.getName === "spark-driver-pod") - assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) - val expectedAnnotations = Map( - CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE, - SPARK_APP_NAME_ANNOTATION -> APP_NAME) - assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations) - assert(preparedDriverSpec.driverPod.getSpec.getRestartPolicy === "Never") - - val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap - val expectedSparkConf = Map( - KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", - "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX) - assert(resolvedSparkConf === expectedSparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala deleted file mode 100644 index 991b03cafb76..000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class DependencyResolutionStepSuite extends SparkFunSuite { - - private val SPARK_JARS = Seq( - "hdfs://localhost:9000/apps/jars/jar1.jar", - "file:///home/user/apps/jars/jar2.jar", - "local:///var/apps/jars/jar3.jar") - - private val SPARK_FILES = Seq( - "file:///home/user/apps/files/file1.txt", - "hdfs://localhost:9000/apps/files/file2.txt", - "local:///var/apps/files/file3.txt") - - private val JARS_DOWNLOAD_PATH = "/mnt/spark-data/jars" - private val FILES_DOWNLOAD_PATH = "/mnt/spark-data/files" - - test("Added dependencies should be resolved in Spark configuration and environment") { - val dependencyResolutionStep = new DependencyResolutionStep( - SPARK_JARS, - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH) - val driverPod = new PodBuilder().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = driverPod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = dependencyResolutionStep.configureDriver(baseDriverSpec) - assert(preparedDriverSpec.driverPod === driverPod) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet - val expectedResolvedSparkJars = Set( - "hdfs://localhost:9000/apps/jars/jar1.jar", - s"$JARS_DOWNLOAD_PATH/jar2.jar", - "/var/apps/jars/jar3.jar") - assert(resolvedSparkJars === expectedResolvedSparkJars) - val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet - val expectedResolvedSparkFiles = Set( - s"$FILES_DOWNLOAD_PATH/file1.txt", - s"hdfs://localhost:9000/apps/files/file2.txt", - s"/var/apps/files/file3.txt") - assert(resolvedSparkFiles === expectedResolvedSparkFiles) - val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala - assert(driverEnv.size === 1) - assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH) - val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet - val expectedResolvedDriverClasspath = Set( - s"$JARS_DOWNLOAD_PATH/jar1.jar", - s"$JARS_DOWNLOAD_PATH/jar2.jar", - "/var/apps/jars/jar3.jar") - assert(resolvedDriverClasspath === expectedResolvedDriverClasspath) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala deleted file mode 100644 index 758871e2ba35..000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.StringReader -import java.util.Properties - -import scala.collection.JavaConverters._ - -import com.google.common.collect.Maps -import io.fabric8.kubernetes.api.model.{ConfigMap, ContainerBuilder, HasMetadata, PodBuilder, SecretBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} -import org.apache.spark.util.Utils - -class DriverInitContainerBootstrapStepSuite extends SparkFunSuite { - - private val CONFIG_MAP_NAME = "spark-init-config-map" - private val CONFIG_MAP_KEY = "spark-init-config-map-key" - - test("The init container bootstrap step should use all of the init container steps") { - val baseDriverSpec = KubernetesDriverSpec( - driverPod = new PodBuilder().build(), - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val initContainerSteps = Seq( - FirstTestInitContainerConfigurationStep, - SecondTestInitContainerConfigurationStep) - val bootstrapStep = new DriverInitContainerBootstrapStep( - initContainerSteps, - CONFIG_MAP_NAME, - CONFIG_MAP_KEY) - - val preparedDriverSpec = bootstrapStep.configureDriver(baseDriverSpec) - - assert(preparedDriverSpec.driverPod.getMetadata.getLabels.asScala === - FirstTestInitContainerConfigurationStep.additionalLabels) - val additionalDriverEnv = preparedDriverSpec.driverContainer.getEnv.asScala - assert(additionalDriverEnv.size === 1) - assert(additionalDriverEnv.head.getName === - FirstTestInitContainerConfigurationStep.additionalMainContainerEnvKey) - assert(additionalDriverEnv.head.getValue === - FirstTestInitContainerConfigurationStep.additionalMainContainerEnvValue) - - assert(preparedDriverSpec.otherKubernetesResources.size === 2) - assert(preparedDriverSpec.otherKubernetesResources.contains( - FirstTestInitContainerConfigurationStep.additionalKubernetesResource)) - assert(preparedDriverSpec.otherKubernetesResources.exists { - case configMap: ConfigMap => - val hasMatchingName = configMap.getMetadata.getName == CONFIG_MAP_NAME - val configMapData = configMap.getData.asScala - val hasCorrectNumberOfEntries = configMapData.size == 1 - val initContainerPropertiesRaw = configMapData(CONFIG_MAP_KEY) - val initContainerProperties = new Properties() - Utils.tryWithResource(new StringReader(initContainerPropertiesRaw)) { - initContainerProperties.load(_) - } - val initContainerPropertiesMap = Maps.fromProperties(initContainerProperties).asScala - val expectedInitContainerProperties = Map( - SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyKey -> - SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyValue) - val hasMatchingProperties = initContainerPropertiesMap == expectedInitContainerProperties - hasMatchingName && hasCorrectNumberOfEntries && hasMatchingProperties - - case _ => false - }) - - val 
initContainers = preparedDriverSpec.driverPod.getSpec.getInitContainers - assert(initContainers.size() === 1) - val initContainerEnv = initContainers.get(0).getEnv.asScala - assert(initContainerEnv.size === 1) - assert(initContainerEnv.head.getName === - SecondTestInitContainerConfigurationStep.additionalInitContainerEnvKey) - assert(initContainerEnv.head.getValue === - SecondTestInitContainerConfigurationStep.additionalInitContainerEnvValue) - - val expectedSparkConf = Map( - INIT_CONTAINER_CONFIG_MAP_NAME.key -> CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY_CONF.key -> CONFIG_MAP_KEY, - SecondTestInitContainerConfigurationStep.additionalDriverSparkConfKey -> - SecondTestInitContainerConfigurationStep.additionalDriverSparkConfValue) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) - } -} - -private object FirstTestInitContainerConfigurationStep extends InitContainerConfigurationStep { - - val additionalLabels = Map("additionalLabelkey" -> "additionalLabelValue") - val additionalMainContainerEnvKey = "TEST_ENV_MAIN_KEY" - val additionalMainContainerEnvValue = "TEST_ENV_MAIN_VALUE" - val additionalKubernetesResource = new SecretBuilder() - .withNewMetadata() - .withName("test-secret") - .endMetadata() - .addToData("secret-key", "secret-value") - .build() - - override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { - val driverPod = new PodBuilder(initContainerSpec.driverPod) - .editOrNewMetadata() - .addToLabels(additionalLabels.asJava) - .endMetadata() - .build() - val mainContainer = new ContainerBuilder(initContainerSpec.driverContainer) - .addNewEnv() - .withName(additionalMainContainerEnvKey) - .withValue(additionalMainContainerEnvValue) - .endEnv() - .build() - initContainerSpec.copy( - driverPod = driverPod, - driverContainer = mainContainer, - dependentResources = initContainerSpec.dependentResources ++ - Seq(additionalKubernetesResource)) - } -} - -private object SecondTestInitContainerConfigurationStep extends InitContainerConfigurationStep { - val additionalInitContainerEnvKey = "TEST_ENV_INIT_KEY" - val additionalInitContainerEnvValue = "TEST_ENV_INIT_VALUE" - val additionalInitContainerPropertyKey = "spark.initcontainer.testkey" - val additionalInitContainerPropertyValue = "testvalue" - val additionalDriverSparkConfKey = "spark.driver.testkey" - val additionalDriverSparkConfValue = "spark.driver.testvalue" - - override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { - val initContainer = new ContainerBuilder(initContainerSpec.initContainer) - .addNewEnv() - .withName(additionalInitContainerEnvKey) - .withValue(additionalInitContainerEnvValue) - .endEnv() - .build() - val initContainerProperties = initContainerSpec.properties ++ - Map(additionalInitContainerPropertyKey -> additionalInitContainerPropertyValue) - val driverSparkConf = initContainerSpec.driverSparkConf ++ - Map(additionalDriverSparkConfKey -> additionalDriverSparkConfValue) - initContainerSpec.copy( - initContainer = initContainer, - properties = initContainerProperties, - driverSparkConf = driverSparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala deleted file mode 100644 index 006ce2668f8a..000000000000 --- 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.Service -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.when -import org.scalatest.BeforeAndAfter - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.util.Clock - -class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter { - - private val SHORT_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length) - - private val LONG_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length + 1) - private val DRIVER_LABELS = Map( - "label1key" -> "label1value", - "label2key" -> "label2value") - - @Mock - private var clock: Clock = _ - - private var sparkConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf(false) - } - - test("Headless service has a port for the driver RPC and the block manager.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - assert(resolvedDriverSpec.otherKubernetesResources.size === 1) - assert(resolvedDriverSpec.otherKubernetesResources.head.isInstanceOf[Service]) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - verifyService( - 9000, - 8080, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - driverService) - } - - test("Hostname and ports are set according to the service name.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) - .set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val expectedServiceName = 
SHORT_RESOURCE_NAME_PREFIX + - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX - val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Ports should resolve to defaults in SparkConf and in the service.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf, - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - verifyService( - DEFAULT_DRIVER_PORT, - DEFAULT_BLOCKMANAGER_PORT, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service]) - assert(resolvedDriverSpec.driverSparkConf.get("spark.driver.port") === - DEFAULT_DRIVER_PORT.toString) - assert(resolvedDriverSpec.driverSparkConf.get( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT) === DEFAULT_BLOCKMANAGER_PORT) - } - - test("Long prefixes should switch to using a generated name.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - when(clock.getTimeMillis()).thenReturn(10000) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}" - assert(driverService.getMetadata.getName === expectedServiceName) - val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Disallow bind address and driver host to be set explicitly.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), - clock) - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver bind address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_BIND_ADDRESS_KEY} is" + - " not supported in Kubernetes mode, as the driver's bind address is managed" + - " and set to the driver pod's IP address.") - } - sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) - sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver host address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_HOST_KEY} is" + - " not supported in Kubernetes mode, as the driver's hostname will be managed via" + - " a Kubernetes service.") - } - } - - private def verifyService( - driverPort: Int, - blockManagerPort: Int, - expectedServiceName: String, - service: Service): Unit = { - assert(service.getMetadata.getName === expectedServiceName) - assert(service.getSpec.getClusterIP === "None") - assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) - assert(service.getSpec.getPorts.size() === 2) - val 
driverServicePorts = service.getSpec.getPorts.asScala - assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) - assert(driverServicePorts.head.getPort.intValue() === driverPort) - assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) - assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) - assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) - assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) - } - - private def verifySparkConfHostNames( - driverSparkConf: SparkConf, expectedHostName: String): Unit = { - assert(driverSparkConf.get( - org.apache.spark.internal.config.DRIVER_HOST_ADDRESS) === expectedHostName) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala deleted file mode 100644 index 4553f9f6b1d4..000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Matchers.any -import org.mockito.Mockito.when -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ - -class BasicInitContainerConfigurationStepSuite extends SparkFunSuite with BeforeAndAfter { - - private val SPARK_JARS = Seq( - "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") - private val SPARK_FILES = Seq( - "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") - private val JARS_DOWNLOAD_PATH = "/var/data/jars" - private val FILES_DOWNLOAD_PATH = "/var/data/files" - private val POD_LABEL = Map("bootstrap" -> "true") - private val INIT_CONTAINER_NAME = "init-container" - private val DRIVER_CONTAINER_NAME = "driver-container" - - @Mock - private var podAndInitContainerBootstrap : InitContainerBootstrap = _ - - before { - MockitoAnnotations.initMocks(this) - when(podAndInitContainerBootstrap.bootstrapInitContainer( - any[PodWithDetachedInitContainer])).thenAnswer(new Answer[PodWithDetachedInitContainer] { - override def answer(invocation: InvocationOnMock) : PodWithDetachedInitContainer = { - val pod = invocation.getArgumentAt(0, classOf[PodWithDetachedInitContainer]) - pod.copy( - pod = new PodBuilder(pod.pod) - .withNewMetadata() - .addToLabels("bootstrap", "true") - .endMetadata() - .withNewSpec().endSpec() - .build(), - initContainer = new ContainerBuilder() - .withName(INIT_CONTAINER_NAME) - .build(), - mainContainer = new ContainerBuilder() - .withName(DRIVER_CONTAINER_NAME) - .build() - )}}) - } - - test("additionalDriverSparkConf with mix of remote files and jars") { - val baseInitStep = new BasicInitContainerConfigurationStep( - SPARK_JARS, - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - podAndInitContainerBootstrap) - val expectedDriverSparkConf = Map( - JARS_DOWNLOAD_LOCATION.key -> JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_LOCATION.key -> FILES_DOWNLOAD_PATH, - INIT_CONTAINER_REMOTE_JARS.key -> "hdfs://localhost:9000/app/jars/jar1.jar", - INIT_CONTAINER_REMOTE_FILES.key -> "hdfs://localhost:9000/app/files/file1.txt") - val initContainerSpec = InitContainerSpec( - Map.empty[String, String], - Map.empty[String, String], - new Container(), - new Container(), - new Pod, - Seq.empty[HasMetadata]) - val returnContainerSpec = baseInitStep.configureInitContainer(initContainerSpec) - assert(expectedDriverSparkConf === returnContainerSpec.properties) - assert(returnContainerSpec.initContainer.getName === INIT_CONTAINER_NAME) - assert(returnContainerSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) - assert(returnContainerSpec.driverPod.getMetadata.getLabels.asScala === POD_LABEL) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala deleted file mode 100644 index 20f2e5bc15df..000000000000 --- 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -class InitContainerConfigOrchestratorSuite extends SparkFunSuite { - - private val DOCKER_IMAGE = "init-container" - private val SPARK_JARS = Seq( - "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") - private val SPARK_FILES = Seq( - "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") - private val JARS_DOWNLOAD_PATH = "/var/data/jars" - private val FILES_DOWNLOAD_PATH = "/var/data/files" - private val DOCKER_IMAGE_PULL_POLICY: String = "IfNotPresent" - private val CUSTOM_LABEL_KEY = "customLabel" - private val CUSTOM_LABEL_VALUE = "customLabelValue" - private val INIT_CONTAINER_CONFIG_MAP_NAME = "spark-init-config-map" - private val INIT_CONTAINER_CONFIG_MAP_KEY = "spark-init-config-map-key" - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" - - test("including basic configuration step") { - val sparkConf = new SparkConf(true) - .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) - .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) - - val orchestrator = new InitContainerConfigOrchestrator( - SPARK_JARS.take(1), - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - DOCKER_IMAGE_PULL_POLICY, - INIT_CONTAINER_CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY, - sparkConf) - val initSteps = orchestrator.getAllConfigurationSteps - assert(initSteps.lengthCompare(1) == 0) - assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) - } - - test("including step to mount user-specified secrets") { - val sparkConf = new SparkConf(false) - .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) - - val orchestrator = new InitContainerConfigOrchestrator( - SPARK_JARS.take(1), - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - DOCKER_IMAGE_PULL_POLICY, - INIT_CONTAINER_CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY, - sparkConf) - val initSteps = orchestrator.getAllConfigurationSteps - assert(initSteps.length === 2) - assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) - assert(initSteps(1).isInstanceOf[InitContainerMountSecretsStep]) - } -} diff --git 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala deleted file mode 100644 index eab4e1765945..000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} - -import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.SecretVolumeUtils - -class InitContainerMountSecretsStepSuite extends SparkFunSuite { - - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" - - test("mounts all given secrets") { - val baseInitContainerSpec = InitContainerSpec( - Map.empty, - Map.empty, - new ContainerBuilder().build(), - new ContainerBuilder().build(), - new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), - Seq.empty) - val secretNamesToMountPaths = Map( - SECRET_FOO -> SECRET_MOUNT_PATH, - SECRET_BAR -> SECRET_MOUNT_PATH) - - val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) - val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) - val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( - baseInitContainerSpec) - - val podWithSecretsMounted = configuredInitContainerSpec.driverPod - val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer - - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.podHasVolume(podWithSecretsMounted, volumeName))) - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.containerHasVolume( - initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala deleted file mode 100644 index 6c557ec4a7c9..000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - 
* contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.rest.k8s - -import java.io.File -import java.util.UUID - -import com.google.common.base.Charsets -import com.google.common.io.Files -import org.mockito.Mockito -import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.util.Utils - -class SparkPodInitContainerSuite extends SparkFunSuite with BeforeAndAfter { - - private val DOWNLOAD_JARS_SECRET_LOCATION = createTempFile("txt") - private val DOWNLOAD_FILES_SECRET_LOCATION = createTempFile("txt") - - private var downloadJarsDir: File = _ - private var downloadFilesDir: File = _ - private var downloadJarsSecretValue: String = _ - private var downloadFilesSecretValue: String = _ - private var fileFetcher: FileFetcher = _ - - override def beforeAll(): Unit = { - downloadJarsSecretValue = Files.toString( - new File(DOWNLOAD_JARS_SECRET_LOCATION), Charsets.UTF_8) - downloadFilesSecretValue = Files.toString( - new File(DOWNLOAD_FILES_SECRET_LOCATION), Charsets.UTF_8) - } - - before { - downloadJarsDir = Utils.createTempDir() - downloadFilesDir = Utils.createTempDir() - fileFetcher = mock[FileFetcher] - } - - after { - downloadJarsDir.delete() - downloadFilesDir.delete() - } - - test("Downloads from remote server should invoke the file fetcher") { - val sparkConf = getSparkConfForRemoteFileDownloads - val initContainerUnderTest = new SparkPodInitContainer(sparkConf, fileFetcher) - initContainerUnderTest.run() - Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/jar1.jar", downloadJarsDir) - Mockito.verify(fileFetcher).fetchFile("hdfs://localhost:9000/jar2.jar", downloadJarsDir) - Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/file.txt", downloadFilesDir) - } - - private def getSparkConfForRemoteFileDownloads: SparkConf = { - new SparkConf(true) - .set(INIT_CONTAINER_REMOTE_JARS, - "http://localhost:9000/jar1.jar,hdfs://localhost:9000/jar2.jar") - .set(INIT_CONTAINER_REMOTE_FILES, - "http://localhost:9000/file.txt") - .set(JARS_DOWNLOAD_LOCATION, downloadJarsDir.getAbsolutePath) - .set(FILES_DOWNLOAD_LOCATION, downloadFilesDir.getAbsolutePath) - } - - private def createTempFile(extension: String): String = { - val dir = Utils.createTempDir() - val file = new File(dir, s"${UUID.randomUUID().toString}.$extension") - Files.write(UUID.randomUUID().toString, file, Charsets.UTF_8) - file.getAbsolutePath - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala deleted file mode 100644 index 7121a802c69c..000000000000 --- 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ -import org.mockito.{AdditionalAnswers, MockitoAnnotations} -import org.mockito.Matchers.any -import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { - - private val driverPodName: String = "driver-pod" - private val driverPodUid: String = "driver-uid" - private val executorPrefix: String = "base" - private val executorImage: String = "executor-image" - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(driverPodName) - .withUid(driverPodUid) - .endMetadata() - .withNewSpec() - .withNodeName("some-node") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private var baseConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - baseConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) - .set(EXECUTOR_CONTAINER_IMAGE, executorImage) - } - - test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactory(baseConf, None, None, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - // The executor pod name and default labels. - assert(executor.getMetadata.getName === s"$executorPrefix-exec-1") - assert(executor.getMetadata.getLabels.size() === 3) - assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") - - // There is exactly 1 container with no volume mounts and default memory limits. - // Default memory limit is 1024M + 384M (minimum overhead constant). - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getImage === executorImage) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) - assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources - .getLimits.get("memory").getAmount === "1408Mi") - - // The pod has no node selector, volumes. 
- assert(executor.getSpec.getNodeSelector.isEmpty) - assert(executor.getSpec.getVolumes.isEmpty) - - checkEnv(executor, Map()) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor pod hostnames get truncated to 63 characters") { - val conf = baseConf.clone() - conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, - "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - - val factory = new ExecutorPodFactory(conf, None, None, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getHostname.length === 63) - } - - test("classpath and extra java options get translated into environment variables") { - val conf = baseConf.clone() - conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") - conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - - val factory = new ExecutorPodFactory(conf, None, None, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) - - checkEnv(executor, - Map("SPARK_JAVA_OPT_0" -> "foo=bar", - "SPARK_EXECUTOR_EXTRA_CLASSPATH" -> "bar=baz", - "qux" -> "quux")) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor secrets get mounted") { - val conf = baseConf.clone() - - val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - val factory = new ExecutorPodFactory( - conf, - Some(secretsBootstrap), - None, - None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") - - // check volume mounted. 
- assert(executor.getSpec.getVolumes.size() === 1) - assert(executor.getSpec.getVolumes.get(0).getSecret.getSecretName === "secret1") - - checkOwnerReferences(executor, driverPodUid) - } - - test("init-container bootstrap step adds an init container") { - val conf = baseConf.clone() - val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) - when(initContainerBootstrap.bootstrapInitContainer( - any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - - val factory = new ExecutorPodFactory( - conf, - None, - Some(initContainerBootstrap), - None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getInitContainers.size() === 1) - checkOwnerReferences(executor, driverPodUid) - } - - test("init-container with secrets mount bootstrap") { - val conf = baseConf.clone() - val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) - when(initContainerBootstrap.bootstrapInitContainer( - any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - - val factory = new ExecutorPodFactory( - conf, - None, - Some(initContainerBootstrap), - Some(secretsBootstrap)) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getInitContainers.size() === 1) - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") - - checkOwnerReferences(executor, driverPodUid) - } - - // There is always exactly one controller reference, and it points to the driver pod. - private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { - assert(executor.getMetadata.getOwnerReferences.size() === 1) - assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) - assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) - } - - // Check that the expected environment variables are present. 
- private def checkEnv(executor: Pod, additionalEnvVars: Map[String, String]): Unit = { - val defaultEnvs = Map( - ENV_EXECUTOR_ID -> "1", - ENV_DRIVER_URL -> "dummy", - ENV_EXECUTOR_CORES -> "1", - ENV_EXECUTOR_MEMORY -> "1g", - ENV_APPLICATION_ID -> "dummy", - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) - val mapEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map { - x => (x.getName, x.getValue) - }.toMap - assert(defaultEnvs === mapEnvs) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index b2f26f205a32..96065e83f069 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.scheduler.cluster.k8s import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} -import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} import io.fabric8.kubernetes.client.Watcher.Action import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} +import org.hamcrest.{BaseMatcher, Description, Matcher} +import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} import org.mockito.Matchers.{any, eq => mockitoEq} import org.mockito.Mockito.{doNothing, never, times, verify, when} import org.scalatest.BeforeAndAfter @@ -31,6 +32,7 @@ import scala.collection.JavaConverters._ import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.rpc._ @@ -47,8 +49,6 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private val SPARK_DRIVER_HOST = "localhost" private val SPARK_DRIVER_PORT = 7077 private val POD_ALLOCATION_INTERVAL = "1m" - private val DRIVER_URL = RpcEndpointAddress( - SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString private val FIRST_EXECUTOR_POD = new PodBuilder() .withNewMetadata() .withName("pod1") @@ -94,7 +94,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var requestExecutorsService: ExecutorService = _ @Mock - private var executorPodFactory: ExecutorPodFactory = _ + private var executorBuilder: KubernetesExecutorBuilder = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -399,7 +399,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn new KubernetesClusterSchedulerBackend( taskSchedulerImpl, rpcEnv, - executorPodFactory, + executorBuilder, kubernetesClient, 
allocatorExecutor, requestExecutorsService) { @@ -428,13 +428,22 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) .endMetadata() .build() - when(executorPodFactory.createExecutorPod( - executorId.toString, - APP_ID, - DRIVER_URL, - sparkConf.getExecutorEnv, - driverPod, - Map.empty)).thenReturn(resolvedPod) - resolvedPod + val resolvedContainer = new ContainerBuilder().build() + when(executorBuilder.buildFromFeatures(Matchers.argThat( + new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any) + : Boolean = { + argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && + argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + .roleSpecificConf.executorId == executorId.toString + } + + override def describeTo(description: Description): Unit = {} + }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) + new PodBuilder(resolvedPod) + .editSpec() + .addToContainers(resolvedContainer) + .endSpec() + .build() } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala new file mode 100644 index 000000000000..f5270623f8ac --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesExecutorBuilderSuite extends SparkFunSuite { + private val BASIC_STEP_TYPE = "basic" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) + private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest = new KubernetesExecutorBuilder( + _ => basicFeatureStep, + _ => mountSecretsStep) + + test("Basic steps are consistently applied.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { + assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedPod.pod.getMetadata.getLabels.get(stepType) === stepType) + } + } +} diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile deleted file mode 100644 index 45fbcd9cd0de..000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# Before building the docker image, first build and make a Spark distribution following -# the instructions in http://spark.apache.org/docs/latest/building-spark.html. -# If this docker file is being used in the context of building your images from a Spark -# distribution, the docker build command should be invoked from the top level directory -# of the Spark distribution. 
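The new KubernetesExecutorBuilderSuite above exercises the feature-step composition that replaces ExecutorPodFactory. A minimal, self-contained sketch of that pattern; the names FeatureStep, configurePod and the label-per-step convention are illustrative stand-ins modelled on the mocked steps in the suite, not Spark's actual API:

object ExecutorBuilderSketch {
  final case class Pod(labels: Map[String, String])

  // Each feature step transforms the pod assembled so far.
  trait FeatureStep { def configurePod(pod: Pod): Pod }

  final class LabelStep(stepType: String) extends FeatureStep {
    // Mirrors the mocked steps in the suite, which tag the pod with one label per step type.
    override def configurePod(pod: Pod): Pod =
      pod.copy(labels = pod.labels + (stepType -> stepType))
  }

  // The builder always applies the basic step and adds optional steps
  // (e.g. secret mounting) only when the corresponding configuration is present.
  def buildFromFeatures(hasSecrets: Boolean): Pod = {
    val steps = Seq(new LabelStep("basic")) ++
      (if (hasSecrets) Seq(new LabelStep("mount-secrets")) else Nil)
    steps.foldLeft(Pod(Map.empty))((pod, step) => step.configurePod(pod))
  }

  def main(args: Array[String]): Unit = {
    assert(buildFromFeatures(hasSecrets = false).labels.keySet == Set("basic"))
    assert(buildFromFeatures(hasSecrets = true).labels.keySet == Set("basic", "mount-secrets"))
  }
}

This is the behaviour the two test cases assert through validateStepTypesApplied.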
E.g.: -# docker build -t spark-driver:latest -f kubernetes/dockerfiles/driver/Dockerfile . - -COPY examples /opt/spark/examples - -CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ - env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ - readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt && \ - if ! [ -z ${SPARK_MOUNTED_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_SUBMIT_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_SUBMIT_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." .; fi && \ - ${JAVA_HOME}/bin/java "${SPARK_DRIVER_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS $SPARK_DRIVER_CLASS $SPARK_DRIVER_ARGS diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile deleted file mode 100644 index 0f806cf7e148..000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# Before building the docker image, first build and make a Spark distribution following -# the instructions in http://spark.apache.org/docs/latest/building-spark.html. -# If this docker file is being used in the context of building your images from a Spark -# distribution, the docker build command should be invoked from the top level directory -# of the Spark distribution. E.g.: -# docker build -t spark-executor:latest -f kubernetes/dockerfiles/executor/Dockerfile . - -COPY examples /opt/spark/examples - -CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ - env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ - readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt && \ - if ! [ -z ${SPARK_MOUNTED_CLASSPATH}+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_EXECUTOR_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_EXECUTOR_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." 
.; fi && \ - ${JAVA_HOME}/bin/java "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" org.apache.spark.executor.CoarseGrainedExecutorBackend --driver-url $SPARK_DRIVER_URL --executor-id $SPARK_EXECUTOR_ID --cores $SPARK_EXECUTOR_CORES --app-id $SPARK_APPLICATION_ID --hostname $SPARK_EXECUTOR_POD_IP diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile similarity index 85% rename from resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile rename to resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 222e777db3a8..9badf8556afc 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -17,12 +17,15 @@ FROM openjdk:8-alpine +ARG spark_jars=jars +ARG img_path=kubernetes/dockerfiles + # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-base:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . +# docker build -t spark:latest -f kubernetes/dockerfiles/spark/Dockerfile . RUN set -ex && \ apk upgrade --no-cache && \ @@ -34,11 +37,12 @@ RUN set -ex && \ ln -sv /bin/bash /bin/sh && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY jars /opt/spark/jars +COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin -COPY conf /opt/spark/conf -COPY kubernetes/dockerfiles/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark/entrypoint.sh /opt/ +COPY examples /opt/spark/examples +COPY data /opt/spark/data ENV SPARK_HOME /opt/spark diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh new file mode 100755 index 000000000000..3e166116aa3f --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -0,0 +1,88 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# echo commands to the terminal output +set -ex + +# Check whether there is a passwd entry for the container UID +myuid=$(id -u) +mygid=$(id -g) +# turn off -e for getent because it will return error code in anonymous uid case +set +e +uidentry=$(getent passwd $myuid) +set -e + +# If there is no passwd entry for the container UID, attempt to create one +if [ -z "$uidentry" ] ; then + if [ -w /etc/passwd ] ; then + echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd + else + echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" + fi +fi + +SPARK_K8S_CMD="$1" +if [ -z "$SPARK_K8S_CMD" ]; then + echo "No command to execute has been provided." 1>&2 + exit 1 +fi +shift 1 + +SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" +env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt +readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt +if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then + SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" +fi +if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then + cp -R "$SPARK_MOUNTED_FILES_DIR/." . +fi + +case "$SPARK_K8S_CMD" in + driver) + CMD=( + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" + ) + ;; + + executor) + CMD=( + ${JAVA_HOME}/bin/java + "${SPARK_JAVA_OPTS[@]}" + -Xms$SPARK_EXECUTOR_MEMORY + -Xmx$SPARK_EXECUTOR_MEMORY + -cp "$SPARK_CLASSPATH" + org.apache.spark.executor.CoarseGrainedExecutorBackend + --driver-url $SPARK_DRIVER_URL + --executor-id $SPARK_EXECUTOR_ID + --cores $SPARK_EXECUTOR_CORES + --app-id $SPARK_APPLICATION_ID + --hostname $SPARK_EXECUTOR_POD_IP + ) + ;; + + *) + echo "Unknown command: $SPARK_K8S_CMD" 1>&2 + exit 1 +esac + +# Execute the container CMD under tini for better hygiene +exec /sbin/tini -s -- "${CMD[@]}" diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 70d0c1750b14..3995d0afeb5f 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index aa378c9d340f..ccf33e8d4283 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.mesos import java.util.concurrent.CountDownLatch -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer @@ -100,7 +100,13 @@ private[mesos] object MesosClusterDispatcher Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler) Utils.initDaemon(log) val conf = new SparkConf - val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) + val dispatcherArgs = try { + new MesosClusterDispatcherArguments(args, conf) + } catch { + case e: SparkException => + printErrorAndExit(e.getMessage()) + null + } conf.setMaster(dispatcherArgs.masterUrl) conf.setAppName(dispatcherArgs.name) dispatcherArgs.zookeeperUrl.foreach { z => 
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 096bb4e1af68..267a4283db9e 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -21,6 +21,7 @@ import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkSubmitUtils import org.apache.spark.util.{IntParam, Utils} private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { @@ -95,9 +96,8 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: parse(tail) case ("--conf") :: value :: tail => - val pair = MesosClusterDispatcher. - parseSparkConfProperty(value) - confProperties(pair._1) = pair._2 + val (k, v) = SparkSubmitUtils.parseSparkConfProperty(value) + confProperties(k) = v parse(tail) case ("--help") :: tail => diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index d224a7325820..b36f46456f9a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -530,9 +530,9 @@ private[spark] class MesosClusterScheduler( .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } .toMap (defaultConf ++ driverConf).foreach { case (key, value) => - options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) } + options ++= Seq("--conf", s"${key}=${value}") } - options + options.map(shellEscape) } /** diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 53f5f61cca48..9b75e4c98344 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -227,7 +227,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") + val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, taskId) + }.getOrElse("") // Set the environment variable through a command prefix // to append to the existing value of the variable diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index d6d939d24610..71a70ff048cc 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ 
b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -111,7 +111,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") + val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, execId) + }.getOrElse("") val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => Utils.libraryPathEnvPrefix(Seq(p)) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala index 7165bfae18a5..a1bf4f0c048f 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.{config, Logging} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.ui.UIUtils import org.apache.spark.util.ThreadUtils @@ -63,7 +64,7 @@ private[spark] class MesosHadoopDelegationTokenManager( val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds) logInfo(s"Initialized tokens: ${SparkHadoopUtil.get.dumpTokens(creds)}") - (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.getDateOfNextUpdate(rt, 0.75)) + (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.nextCredentialRenewalTime(rt, conf)) } catch { case e: Exception => logError(s"Failed to fetch Hadoop delegation tokens $e") @@ -104,8 +105,10 @@ private[spark] class MesosHadoopDelegationTokenManager( } catch { case e: Exception => // Log the error and try to write new tokens back in an hour - logWarning("Couldn't broadcast tokens, trying again in an hour", e) - credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) + val delay = TimeUnit.SECONDS.toMillis(conf.get(config.CREDENTIALS_RENEWAL_RETRY_WAIT)) + logWarning( + s"Couldn't broadcast tokens, trying again in ${UIUtils.formatDuration(delay)}", e) + credentialRenewerThread.schedule(this, delay, TimeUnit.MILLISECONDS) return } scheduleRenewal(this) @@ -135,7 +138,7 @@ private[spark] class MesosHadoopDelegationTokenManager( "related configurations in the target services.") currTime } else { - SparkHadoopUtil.getDateOfNextUpdate(nextRenewalTime, 0.75) + SparkHadoopUtil.nextCredentialRenewalTime(nextRenewalTime, conf) } logInfo(s"Time of next renewal is in ${timeOfNextRenewal - System.currentTimeMillis()} ms") diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index e75450369ad8..ecbcc960fc5a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala 
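The MesosSchedulerUtils hunk below consolidates principal and secret handling into buildCredentials, which now also accepts values from environment variables and from files. A self-contained sketch of the lookup order it applies for the principal (the secret follows the same pattern); the map-based conf/env parameters and readFile are stand-ins for SparkConf, getenv and Files.toString:

object MesosCredentialLookupSketch {
  // Precedence: spark.mesos.principal, then SPARK_MESOS_PRINCIPAL, then the contents of
  // spark.mesos.principal.file, then the contents of SPARK_MESOS_PRINCIPAL_FILE.
  def resolvePrincipal(
      conf: Map[String, String],
      env: Map[String, String],
      readFile: String => String): Option[String] = {
    conf.get("spark.mesos.principal")
      .orElse(env.get("SPARK_MESOS_PRINCIPAL"))
      .orElse(
        conf.get("spark.mesos.principal.file")
          .orElse(env.get("SPARK_MESOS_PRINCIPAL_FILE"))
          .map(readFile))
  }

  def main(args: Array[String]): Unit = {
    // The configuration property wins over the environment variable, matching the
    // "Principal specification preference" test added further below.
    val resolved = resolvePrincipal(
      conf = Map("spark.mesos.principal" -> "test-principal"),
      env = Map("SPARK_MESOS_PRINCIPAL" -> "other-principal"),
      readFile = _ => sys.error("no file should be read in this case"))
    assert(resolved.contains("test-principal"))
  }
}

As in the diff itself, a secret configured without a principal is still rejected with a SparkException; that check is unchanged and is covered by the "Secret specified with no principal" test.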
@@ -17,6 +17,8 @@ package org.apache.spark.scheduler.cluster.mesos +import java.io.File +import java.nio.charset.StandardCharsets import java.util.{List => JList} import java.util.concurrent.CountDownLatch @@ -25,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.base.Splitter +import com.google.common.io.Files import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.FrameworkInfo.Capability @@ -71,26 +74,15 @@ trait MesosSchedulerUtils extends Logging { failoverTimeout: Option[Double] = None, frameworkId: Option[String] = None): SchedulerDriver = { val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName) - val credBuilder = Credential.newBuilder() + fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( + conf.get(DRIVER_HOST_ADDRESS))) webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) } checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) } failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) } frameworkId.foreach { id => fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build()) } - fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( - conf.get(DRIVER_HOST_ADDRESS))) - conf.getOption("spark.mesos.principal").foreach { principal => - fwInfoBuilder.setPrincipal(principal) - credBuilder.setPrincipal(principal) - } - conf.getOption("spark.mesos.secret").foreach { secret => - credBuilder.setSecret(secret) - } - if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { - throw new SparkException( - "spark.mesos.principal must be configured when spark.mesos.secret is set") - } + conf.getOption("spark.mesos.role").foreach { role => fwInfoBuilder.setRole(role) } @@ -98,6 +90,7 @@ trait MesosSchedulerUtils extends Logging { if (maxGpus > 0) { fwInfoBuilder.addCapabilities(Capability.newBuilder().setType(Capability.Type.GPU_RESOURCES)) } + val credBuilder = buildCredentials(conf, fwInfoBuilder) if (credBuilder.hasPrincipal) { new MesosSchedulerDriver( scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build()) @@ -106,6 +99,40 @@ trait MesosSchedulerUtils extends Logging { } } + def buildCredentials( + conf: SparkConf, + fwInfoBuilder: Protos.FrameworkInfo.Builder): Protos.Credential.Builder = { + val credBuilder = Credential.newBuilder() + conf.getOption("spark.mesos.principal") + .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL"))) + .orElse( + conf.getOption("spark.mesos.principal.file") + .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL_FILE"))) + .map { principalFile => + Files.toString(new File(principalFile), StandardCharsets.UTF_8) + } + ).foreach { principal => + fwInfoBuilder.setPrincipal(principal) + credBuilder.setPrincipal(principal) + } + conf.getOption("spark.mesos.secret") + .orElse(Option(conf.getenv("SPARK_MESOS_SECRET"))) + .orElse( + conf.getOption("spark.mesos.secret.file") + .orElse(Option(conf.getenv("SPARK_MESOS_SECRET_FILE"))) + .map { secretFile => + Files.toString(new File(secretFile), StandardCharsets.UTF_8) + } + ).foreach { secret => + credBuilder.setSecret(secret) + } + if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { + throw new SparkException( + "spark.mesos.principal must be configured when spark.mesos.secret is set") + } + credBuilder + } + /** * Starts the MesosSchedulerDriver and stores the current running driver to this new 
instance. * This driver is expected to not be running. diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index 7df738958f85..8d90e1a8591a 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.scheduler.cluster.mesos +import java.io.{File, FileNotFoundException} + import scala.collection.JavaConverters._ import scala.language.reflectiveCalls -import org.apache.mesos.Protos.{Resource, Value} +import com.google.common.io.Files +import org.apache.mesos.Protos.{FrameworkInfo, Resource, Value} import org.mockito.Mockito._ import org.scalatest._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.internal.config._ +import org.apache.spark.util.SparkConfWithEnv class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar { @@ -237,4 +241,157 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1} portsToUse.isEmpty shouldBe true } + + test("Principal specified via spark.mesos.principal") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via spark.mesos.principal.file") { + val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + pFile.deleteOnExit() + Files.write("test-principal".getBytes("UTF-8"), pFile); + val conf = new SparkConf() + conf.set("spark.mesos.principal.file", pFile.getAbsolutePath()) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via spark.mesos.principal.file that does not exist") { + val conf = new SparkConf() + conf.set("spark.mesos.principal.file", "/tmp/does-not-exist") + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "test-principal")) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE") { + val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + pFile.deleteOnExit() + Files.write("test-principal".getBytes("UTF-8"), pFile); + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> pFile.getAbsolutePath())) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE that does not exist") { + val conf = new 
SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> "/tmp/does-not-exist")) + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Secret specified via spark.mesos.secret") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret", "my-secret") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via spark.mesos.secret.file") { + val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + sFile.deleteOnExit() + Files.write("my-secret".getBytes("UTF-8"), sFile); + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret.file", sFile.getAbsolutePath()) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via spark.mesos.secret.file that does not exist") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret.file", "/tmp/does-not-exist") + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specified via SPARK_MESOS_SECRET") { + val env = Map("SPARK_MESOS_SECRET" -> "my-secret") + val conf = new SparkConfWithEnv(env) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via SPARK_MESOS_SECRET_FILE") { + val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + sFile.deleteOnExit() + Files.write("my-secret".getBytes("UTF-8"), sFile); + + val sFilePath = sFile.getAbsolutePath() + val env = Map("SPARK_MESOS_SECRET_FILE" -> sFilePath) + val conf = new SparkConfWithEnv(env) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Secret specified with no principal") { + val conf = new SparkConf() + conf.set("spark.mesos.secret", "my-secret") + + intercept[SparkException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specification preference") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "other-principal")) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Secret specification preference") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_SECRET" -> "other-secret")) + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret", "my-secret") + + val credBuilder = 
utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } } diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 43a7ce95bd3d..37e25ceecb88 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index b2576b0d7263..595077e7e809 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} -import java.lang.reflect.InvocationTargetException +import java.lang.reflect.{InvocationTargetException, Modifier} import java.net.{Socket, URI, URL} import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} @@ -29,7 +29,7 @@ import scala.concurrent.duration.Duration import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -40,7 +40,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager} +import org.apache.spark.deploy.yarn.security.AMCredentialRenewer import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc._ @@ -78,42 +78,43 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends private val yarnConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf)) - private val ugi = { - val original = UserGroupInformation.getCurrentUser() - - // If a principal and keytab were provided, log in to kerberos, and set up a thread to - // renew the kerberos ticket when needed. Because the UGI API does not expose the TTL - // of the TGT, use a configuration to define how often to check that a relogin is necessary. - // checkTGTAndReloginFromKeytab() is a no-op if the relogin is not yet needed. 
- val principal = sparkConf.get(PRINCIPAL).orNull - val keytab = sparkConf.get(KEYTAB).orNull - if (principal != null && keytab != null) { - UserGroupInformation.loginUserFromKeytab(principal, keytab) - - val renewer = new Thread() { - override def run(): Unit = Utils.tryLogNonFatalError { - while (true) { - TimeUnit.SECONDS.sleep(sparkConf.get(KERBEROS_RELOGIN_PERIOD)) - UserGroupInformation.getCurrentUser().checkTGTAndReloginFromKeytab() - } - } + private val userClassLoader = { + val classpath = Client.getUserClasspath(sparkConf) + val urls = classpath.map { entry => + new URL("file:" + new File(entry.getPath()).getAbsolutePath()) + } + + if (isClusterMode) { + if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { + new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } else { + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } - renewer.setName("am-kerberos-renewer") - renewer.setDaemon(true) - renewer.start() - - // Transfer the original user's tokens to the new user, since that's needed to connect to - // YARN. It also copies over any delegation tokens that might have been created by the - // client, which will then be transferred over when starting executors (until new ones - // are created by the periodic task). - val newUser = UserGroupInformation.getCurrentUser() - SparkHadoopUtil.get.transferCredentials(original, newUser) - newUser } else { - SparkHadoopUtil.get.createSparkUser() + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } } + private val credentialRenewer: Option[AMCredentialRenewer] = sparkConf.get(KEYTAB).map { _ => + new AMCredentialRenewer(sparkConf, yarnConf) + } + + private val ugi = credentialRenewer match { + case Some(cr) => + // Set the context class loader so that the token renewer has access to jars distributed + // by the user. + val currentLoader = Thread.currentThread().getContextClassLoader() + Thread.currentThread().setContextClassLoader(userClassLoader) + try { + cr.start() + } finally { + Thread.currentThread().setContextClassLoader(currentLoader) + } + + case _ => + SparkHadoopUtil.get.createSparkUser() + } + private val client = doAsUser { new YarnRMClient() } // Default to twice the number of executors (twice the maximum number of executors if dynamic @@ -147,23 +148,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // A flag to check whether user has initialized spark context @volatile private var registered = false - private val userClassLoader = { - val classpath = Client.getUserClasspath(sparkConf) - val urls = classpath.map { entry => - new URL("file:" + new File(entry.getPath()).getAbsolutePath()) - } - - if (isClusterMode) { - if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { - new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } else { - new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } - } else { - new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } - } - // Lock for controlling the allocator (heartbeat) thread. private val allocatorLock = new Object() @@ -188,8 +172,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // In cluster mode, used to tell the AM when the user's SparkContext has been initialized. private val sparkContextPromise = Promise[SparkContext]() - private var credentialRenewer: AMCredentialRenewer = _ - // Load the list of localized files set by the client. 
This is used when launching executors, // and is loaded here so that these configs don't pollute the Web UI's environment page in // cluster mode. @@ -315,31 +297,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - // If the credentials file config is present, we must periodically renew tokens. So create - // a new AMDelegationTokenRenewer - if (sparkConf.contains(CREDENTIALS_FILE_PATH)) { - // Start a short-lived thread for AMCredentialRenewer, the only purpose is to set the - // classloader so that main jar and secondary jars could be used by AMCredentialRenewer. - val credentialRenewerThread = new Thread { - setName("AMCredentialRenewerStarter") - setContextClassLoader(userClassLoader) - - override def run(): Unit = { - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - yarnConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - - val credentialRenewer = - new AMCredentialRenewer(sparkConf, yarnConf, credentialManager) - credentialRenewer.scheduleLoginFromKeytab() - } - } - - credentialRenewerThread.start() - credentialRenewerThread.join() - } - if (isClusterMode) { runDriver() } else { @@ -351,7 +308,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("Uncaught exception: ", e) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, - "Uncaught exception: " + e) + "Uncaught exception: " + StringUtils.stringifyException(e)) } } @@ -408,16 +365,27 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logDebug("shutting down user thread") userClassThread.interrupt() } - if (!inShutdown && credentialRenewer != null) { - credentialRenewer.stop() - credentialRenewer = null + if (!inShutdown) { + credentialRenewer.foreach(_.stop()) } } } } private def sparkContextInitialized(sc: SparkContext) = { - sparkContextPromise.success(sc) + sparkContextPromise.synchronized { + // Notify runDriver function that SparkContext is available + sparkContextPromise.success(sc) + // Pause the user class thread in order to make proper initialization in runDriver function. + sparkContextPromise.wait() + } + } + + private def resumeDriver(): Unit = { + // When initialization in runDriver happened the user class thread has to be resumed. + sparkContextPromise.synchronized { + sparkContextPromise.notify() + } } private def registerAM( @@ -427,11 +395,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends uiAddress: Option[String]) = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() - val historyAddress = - _sparkConf.get(HISTORY_SERVER_ADDRESS) - .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } - .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } - .getOrElse("") + val historyAddress = ApplicationMaster + .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) val driverUrl = RpcEndpointAddress( _sparkConf.get("spark.driver.host"), @@ -458,6 +423,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends securityMgr, localResources) + credentialRenewer.foreach(_.setDriverRef(driverRef)) + // Initialize the AM endpoint *after* the allocator has been initialized. This ensures // that when the driver sends an initial executor request (e.g. after an AM restart), // the allocator is ready to service requests. 
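The sparkContextInitialized/resumeDriver changes above introduce a small monitor-based handshake: the user-class thread publishes the SparkContext through the promise and then parks on the promise's monitor until runDriver has finished its own initialization and calls notify(). A stand-alone sketch of that handshake with illustrative names (a plain lock object instead of the promise, a string instead of the SparkContext):

object AmHandshakeSketch {
  private val lock = new Object
  @volatile private var published: Option[String] = None

  // Runs on the user-class thread: corresponds to sparkContextPromise.success(sc)
  // followed by sparkContextPromise.wait() in the AM.
  def sparkContextInitialized(sc: String): Unit = lock.synchronized {
    published = Some(sc)
    lock.wait() // pause the user class until the AM resumes it
  }

  // Runs on the AM thread once runDriver has completed its initialization.
  def resumeDriver(): Unit = lock.synchronized {
    lock.notify()
  }

  def main(args: Array[String]): Unit = {
    val userClassThread = new Thread(new Runnable {
      override def run(): Unit = sparkContextInitialized("sc")
    })
    userClassThread.start()
    while (published.isEmpty) Thread.sleep(10) // the AM waits for the context to appear
    resumeDriver()
    userClassThread.join()
    assert(published.contains("sc"))
  }
}

In the actual change, resumeDriver() is also invoked from the finally block of runDriver so that the user class is not left blocked if initialization fails.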
@@ -499,6 +466,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // if the user app did not create a SparkContext. throw new IllegalStateException("User did not initialize spark context!") } + resumeDriver() userClassThread.join() } catch { case e: SparkException if e.getCause().isInstanceOf[TimeoutException] => @@ -508,6 +476,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") + } finally { + resumeDriver() } } @@ -705,9 +675,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val userThread = new Thread { override def run() { try { - mainMethod.invoke(null, userArgs.toArray) - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - logDebug("Done running users class") + if (!Modifier.isStatic(mainMethod.getModifiers)) { + logError(s"Could not find static main method in object ${args.userClass}") + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS) + } else { + mainMethod.invoke(null, userArgs.toArray) + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + logDebug("Done running user class") + } } catch { case e: InvocationTargetException => e.getCause match { @@ -721,7 +696,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("User class threw exception: " + cause, cause) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS, - "User class threw exception: " + cause) + "User class threw exception: " + StringUtils.stringifyException(cause)) } sparkContextPromise.tryFailure(e.getCause()) } finally { @@ -834,6 +809,16 @@ object ApplicationMaster extends Logging { master.getAttemptId } + private[spark] def getHistoryServerAddress( + sparkConf: SparkConf, + yarnConf: YarnConfiguration, + appId: String, + attemptId: String): String = { + sparkConf.get(HISTORY_SERVER_ADDRESS) + .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } + .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } + .getOrElse("") + } } /** diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 15328d08b3b5..bafb129032b4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -93,11 +93,21 @@ private[spark] class Client( private val distCacheMgr = new ClientDistributedCacheManager() - private var loginFromKeytab = false - private var principal: String = null - private var keytab: String = null - private var credentials: Credentials = null - private var amKeytabFileName: String = null + private val principal = sparkConf.get(PRINCIPAL).orNull + private val keytab = sparkConf.get(KEYTAB).orNull + private val loginFromKeytab = principal != null + private val amKeytabFileName: String = { + require((principal == null) == (keytab == null), + "Both principal and keytab must be defined, or neither.") + if (loginFromKeytab) { + logInfo(s"Kerberos credentials: principal = $principal, keytab = $keytab") + // Generate a file name that can be used for the keytab file, that does not conflict + // with any user file. 
+ new File(keytab).getName() + "-" + UUID.randomUUID().toString + } else { + null + } + } private val launcherBackend = new LauncherBackend() { override protected def conf: SparkConf = sparkConf @@ -120,11 +130,6 @@ private[spark] class Client( private val appStagingBaseDir = sparkConf.get(STAGING_DIR).map { new Path(_) } .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory()) - private val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - def reportLauncherState(state: SparkAppHandle.State): Unit = { launcherBackend.setState(state) } @@ -145,9 +150,6 @@ private[spark] class Client( var appId: ApplicationId = null try { launcherBackend.connect() - // Setup the credentials before doing anything else, - // so we have don't have issues at any point. - setupCredentials() yarnClient.init(hadoopConf) yarnClient.start() @@ -288,8 +290,26 @@ private[spark] class Client( appContext } - /** Set up security tokens for launching our ApplicationMaster container. */ + /** + * Set up security tokens for launching our ApplicationMaster container. + * + * This method will obtain delegation tokens from all the registered providers, and set them in + * the AM's launch context. + */ private def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = { + val credentials = UserGroupInformation.getCurrentUser().getCredentials() + val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) + credentialManager.obtainDelegationTokens(hadoopConf, credentials) + + // When using a proxy user, copy the delegation tokens to the user's credentials. Avoid + // that for regular users, since in those case the user already has access to the TGT, + // and adding delegation tokens could lead to expired or cancelled tokens being used + // later, as reported in SPARK-15754. + val currentUser = UserGroupInformation.getCurrentUser() + if (SparkHadoopUtil.get.isProxyUser(currentUser)) { + currentUser.addCredentials(credentials) + } + val dob = new DataOutputBuffer credentials.writeTokenStorageToStream(dob) amContainer.setTokens(ByteBuffer.wrap(dob.getData)) @@ -384,36 +404,6 @@ private[spark] class Client( // and add them as local resources to the application master. val fs = destDir.getFileSystem(hadoopConf) - // Merge credentials obtained from registered providers - val nearestTimeOfNextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, credentials) - - if (credentials != null) { - // Add credentials to current user's UGI, so that following operations don't need to use the - // Kerberos tgt to get delegations again in the client side. - val currentUser = UserGroupInformation.getCurrentUser() - if (SparkHadoopUtil.get.isProxyUser(currentUser)) { - currentUser.addCredentials(credentials) - } - logDebug(SparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) - } - - // If we use principal and keytab to login, also credentials can be renewed some time - // after current time, we should pass the next renewal and updating time to credential - // renewer and updater. - if (loginFromKeytab && nearestTimeOfNextRenewal > System.currentTimeMillis() && - nearestTimeOfNextRenewal != Long.MaxValue) { - - // Valid renewal time is 75% of next renewal time, and the valid update time will be - // slightly later then renewal time (80% of next renewal time). This is to make sure - // credentials are renewed and updated before expired. 
- val currTime = System.currentTimeMillis() - val renewalTime = (nearestTimeOfNextRenewal - currTime) * 0.75 + currTime - val updateTime = (nearestTimeOfNextRenewal - currTime) * 0.8 + currTime - - sparkConf.set(CREDENTIALS_RENEWAL_TIME, renewalTime.toLong) - sparkConf.set(CREDENTIALS_UPDATE_TIME, updateTime.toLong) - } - // Used to keep track of URIs added to the distributed cache. If the same URI is added // multiple times, YARN will fail to launch containers for the app with an internal // error. @@ -696,7 +686,13 @@ private[spark] class Client( } } - Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => + // SPARK-23630: during testing, Spark scripts filter out hadoop conf dirs so that user's + // environments do not interfere with tests. This allows a special env variable during + // tests so that custom conf dirs can be used by unit tests. + val confDirs = Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR") ++ + (if (Utils.isTesting) Seq("SPARK_TEST_HADOOP_CONF_DIR") else Nil) + + confDirs.foreach { envKey => sys.env.get(envKey).foreach { path => val dir = new File(path) if (dir.isDirectory()) { @@ -753,7 +749,7 @@ private[spark] class Client( // Save the YARN configuration into a separate file that will be overlayed on top of the // cluster's Hadoop conf. - confStream.putNextEntry(new ZipEntry(SPARK_HADOOP_CONF_FILE)) + confStream.putNextEntry(new ZipEntry(SparkHadoopUtil.SPARK_HADOOP_CONF_FILE)) hadoopConf.writeXml(confStream) confStream.closeEntry() @@ -787,11 +783,6 @@ private[spark] class Client( populateClasspath(args, hadoopConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH)) env("SPARK_YARN_STAGING_DIR") = stagingDirPath.toString env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() - if (loginFromKeytab) { - val credentialsFile = "credentials-" + UUID.randomUUID().toString - sparkConf.set(CREDENTIALS_FILE_PATH, new Path(stagingDirPath, credentialsFile).toString) - logInfo(s"Credentials file set to: $credentialsFile") - } // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* val amEnvPrefix = "spark.yarn.appMasterEnv." @@ -901,7 +892,9 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), sys.props.get("spark.driver.libraryPath")).flatten @@ -923,7 +916,9 @@ private[spark] class Client( s"(was '$opts'). Use spark.yarn.am.memory instead." 
throw new SparkException(msg) } - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) @@ -1008,25 +1003,6 @@ private[spark] class Client( amContainer } - def setupCredentials(): Unit = { - loginFromKeytab = sparkConf.contains(PRINCIPAL.key) - if (loginFromKeytab) { - principal = sparkConf.get(PRINCIPAL).get - keytab = sparkConf.get(KEYTAB).orNull - - require(keytab != null, "Keytab must be specified when principal is specified.") - logInfo("Attempting to login to the Kerberos" + - s" using principal: $principal and keytab: $keytab") - val f = new File(keytab) - // Generate a file name that can be used for the keytab file, that does not conflict - // with any user file. - amKeytabFileName = f.getName + "-" + UUID.randomUUID().toString - sparkConf.set(PRINCIPAL.key, principal) - } - // Defensive copy of the credentials - credentials = new Credentials(UserGroupInformation.getCurrentUser.getCredentials) - } - /** * Report the state of an application until it has exited, either successfully or * due to some failure, then return a pair of the yarn application state (FINISHED, FAILED, @@ -1220,10 +1196,6 @@ private object Client extends Logging { // Name of the file in the conf archive containing Spark configuration. val SPARK_CONF_FILE = "__spark_conf__.properties" - // Name of the file containing the gateway's Hadoop configuration, to be overlayed on top of the - // cluster's Hadoop config. - val SPARK_HADOOP_CONF_FILE = "__spark_hadoop_conf__.xml" - // Subdirectory where the user's python files (not archives) will be placed. val LOCALIZED_PYTHON_DIR = "__pyfiles__" @@ -1421,15 +1393,20 @@ private object Client extends Logging { } /** - * Return whether the two file systems are the same. + * Return whether two URI represent file system are the same */ - private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { - val srcUri = srcFs.getUri() - val dstUri = destFs.getUri() + private[spark] def compareUri(srcUri: URI, dstUri: URI): Boolean = { + if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) { return false } + val srcAuthority = srcUri.getAuthority() + val dstAuthority = dstUri.getAuthority() + if (srcAuthority != null && !srcAuthority.equalsIgnoreCase(dstAuthority)) { + return false + } + var srcHost = srcUri.getHost() var dstHost = dstUri.getHost() @@ -1447,6 +1424,17 @@ private object Client extends Logging { } Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort() + + } + + /** + * Return whether the two file systems are the same. 
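The new compareUri helper above treats two URIs as pointing at the same filesystem only when the scheme matches, the authority matches case-insensitively, and host plus port agree; the ClientSuite cases added later in this patch pin that behaviour down. A simplified sketch of the comparison order (the real method does some extra host-name canonicalization before comparing):

    import java.net.URI

    // Simplified sketch, not the patch implementation.
    def sameFileSystem(src: URI, dst: URI): Boolean = {
      if (src.getScheme == null || src.getScheme != dst.getScheme) return false
      val srcAuthority = src.getAuthority
      if (srcAuthority != null && !srcAuthority.equalsIgnoreCase(dst.getAuthority)) return false
      src.getHost == dst.getHost && src.getPort == dst.getPort
    }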
+ */ + protected def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { + val srcUri = srcFs.getUri() + val dstUri = destFs.getUri() + + compareUri(srcUri, dstUri) } /** diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 3f4d236571ff..a2a18cdff65a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -141,7 +141,8 @@ private[yarn] class ExecutorRunnable( // Set extra Java options for the executor, if defined sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + val subsOpt = Utils.substituteAppNExecIds(opts, appId, executorId) + javaOpts ++= Utils.splitCommandString(subsOpt).map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) @@ -220,12 +221,6 @@ private[yarn] class ExecutorRunnable( val env = new HashMap[String, String]() Client.populateClasspath(null, conf, sparkConf, env, sparkConf.get(EXECUTOR_CLASS_PATH)) - sparkConf.getExecutorEnv.foreach { case (key, value) => - // This assumes each executor environment variable set here is a path - // This is kept for backward compatibility and consistency with hadoop - YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) - } - // lookup appropriate http scheme for container log urls val yarnHttpPolicy = conf.get( YarnConfiguration.YARN_HTTP_POLICY_KEY, @@ -233,6 +228,20 @@ private[yarn] class ExecutorRunnable( ) val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + System.getenv().asScala.filterKeys(_.startsWith("SPARK")) + .foreach { case (k, v) => env(k) = v } + + sparkConf.getExecutorEnv.foreach { case (key, value) => + if (key == Environment.CLASSPATH.name()) { + // If the key of env variable is CLASSPATH, we assume it is a path and append it. + // This is kept for backward compatibility and consistency with hadoop + YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) + } else { + // For other env variables, simply overwrite the value. 
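In ExecutorRunnable, executor environment variables from spark.executorEnv.* now overwrite anything inherited from the AM's environment, with CLASSPATH as the one key that is still appended path-style; the SPARK* system variables are copied in first, so user settings win. A small sketch of that merge rule on plain maps (the real code writes into YARN's launch-context env and uses the cluster path separator, so the separator below is an assumption):

    // Sketch only: CLASSPATH entries are joined like a path, every other key is a plain overwrite.
    def mergeExecutorEnv(
        inherited: Map[String, String],
        executorEnv: Map[String, String]): Map[String, String] = {
      executorEnv.foldLeft(inherited) { (env, entry) =>
        val (key, value) = entry
        if (key == "CLASSPATH") {
          val joined = env.get(key).map(_ + java.io.File.pathSeparator + value).getOrElse(value)
          env.updated(key, joined)
        } else {
          env.updated(key, value)
        }
      }
    }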
+ env(key) = value + } + } + // Add log urls container.foreach { c => sys.env.get("SPARK_USER").foreach { user => @@ -245,8 +254,6 @@ private[yarn] class ExecutorRunnable( } } - System.getenv().asScala.filterKeys(_.startsWith("SPARK")) - .foreach { case (k, v) => env(k) = v } env } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 506adb363aa9..ebee3d431744 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -81,7 +81,8 @@ private[yarn] class YarnAllocator( private val releasedContainers = Collections.newSetFromMap[ContainerId]( new ConcurrentHashMap[ContainerId, java.lang.Boolean]) - private val numExecutorsRunning = new AtomicInteger(0) + private val runningExecutors = Collections.newSetFromMap[String]( + new ConcurrentHashMap[String, java.lang.Boolean]()) private val numExecutorsStarting = new AtomicInteger(0) @@ -166,7 +167,7 @@ private[yarn] class YarnAllocator( clock = newClock } - def getNumExecutorsRunning: Int = numExecutorsRunning.get() + def getNumExecutorsRunning: Int = runningExecutors.size() def getNumExecutorsFailed: Int = synchronized { val endTime = clock.getTimeMillis() @@ -242,12 +243,11 @@ private[yarn] class YarnAllocator( * Request that the ResourceManager release the container running the specified executor. */ def killExecutor(executorId: String): Unit = synchronized { - if (executorIdToContainer.contains(executorId)) { - val container = executorIdToContainer.get(executorId).get - internalReleaseContainer(container) - numExecutorsRunning.decrementAndGet() - } else { - logWarning(s"Attempted to kill unknown executor $executorId!") + executorIdToContainer.get(executorId) match { + case Some(container) if !releasedContainers.contains(container.getId) => + internalReleaseContainer(container) + runningExecutors.remove(executorId) + case _ => logWarning(s"Attempted to kill unknown executor $executorId!") } } @@ -274,7 +274,7 @@ private[yarn] class YarnAllocator( "Launching executor count: %d. Cluster resources: %s.") .format( allocatedContainers.size, - numExecutorsRunning.get, + runningExecutors.size, numExecutorsStarting.get, allocateResponse.getAvailableResources)) @@ -286,7 +286,7 @@ private[yarn] class YarnAllocator( logDebug("Completed %d containers".format(completedContainers.size)) processCompletedContainers(completedContainers.asScala) logDebug("Finished processing %d completed containers. Current running executor count: %d." 
- .format(completedContainers.size, numExecutorsRunning.get)) + .format(completedContainers.size, runningExecutors.size)) } } @@ -300,9 +300,9 @@ private[yarn] class YarnAllocator( val pendingAllocate = getPendingAllocate val numPendingAllocate = pendingAllocate.size val missing = targetNumExecutors - numPendingAllocate - - numExecutorsStarting.get - numExecutorsRunning.get + numExecutorsStarting.get - runningExecutors.size logDebug(s"Updating resource requests, target: $targetNumExecutors, " + - s"pending: $numPendingAllocate, running: ${numExecutorsRunning.get}, " + + s"pending: $numPendingAllocate, running: ${runningExecutors.size}, " + s"executorsStarting: ${numExecutorsStarting.get}") if (missing > 0) { @@ -502,7 +502,7 @@ private[yarn] class YarnAllocator( s"for executor with ID $executorId") def updateInternalState(): Unit = synchronized { - numExecutorsRunning.incrementAndGet() + runningExecutors.add(executorId) numExecutorsStarting.decrementAndGet() executorIdToContainer(executorId) = container containerIdToExecutorId(container.getId) = executorId @@ -513,7 +513,7 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.put(containerId, executorHostname) } - if (numExecutorsRunning.get < targetNumExecutors) { + if (runningExecutors.size() < targetNumExecutors) { numExecutorsStarting.incrementAndGet() if (launchContainers) { launcherPool.execute(new Runnable { @@ -554,7 +554,7 @@ private[yarn] class YarnAllocator( } else { logInfo(("Skip launching executorRunnable as running executors count: %d " + "reached target executors count: %d.").format( - numExecutorsRunning.get, targetNumExecutors)) + runningExecutors.size, targetNumExecutors)) } } } @@ -569,7 +569,11 @@ private[yarn] class YarnAllocator( val exitReason = if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. - numExecutorsRunning.decrementAndGet() + containerIdToExecutorId.get(containerId) match { + case Some(executorId) => runningExecutors.remove(executorId) + case None => logWarning(s"Cannot find executorId for container: ${containerId.toString}") + } + logInfo("Completed container %s%s (state: %s, exit status: %s)".format( containerId, onHostStr, @@ -736,7 +740,8 @@ private object YarnAllocator { def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { val matcher = pattern.matcher(diagnostics) val diag = if (matcher.find()) " " + matcher.group() + "." else "" - ("Container killed by YARN for exceeding memory limits." + diag - + " Consider boosting spark.yarn.executor.memoryOverhead.") + s"Container killed by YARN for exceeding memory limits. $diag " + + "Consider boosting spark.yarn.executor.memoryOverhead or " + + "disabling yarn.nodemanager.vmem-check-enabled because of YARN-4714." 
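Switching YarnAllocator from an AtomicInteger counter to a concurrent set of running executor ids makes killExecutor and processCompletedContainers idempotent: releasing or completing the same container twice can only drop the running count once, which the new YarnAllocatorSuite tests below exercise directly. A minimal sketch of the idea as a standalone tracker (hypothetical class, not part of the patch):

    import java.util.Collections
    import java.util.concurrent.ConcurrentHashMap

    // Sketch only: a set keyed by executor id is naturally idempotent, unlike a raw counter.
    class RunningExecutorTracker {
      private val running =
        Collections.newSetFromMap[String](new ConcurrentHashMap[String, java.lang.Boolean]())

      def markRunning(executorId: String): Unit = running.add(executorId)

      // Safe to call repeatedly for the same executor; the size drops at most once.
      def markStopped(executorId: String): Unit = running.remove(executorId)

      def numRunning: Int = running.size()
    }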
} } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index c1ae12aabb8c..17234b120ae1 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -29,7 +29,6 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.util.Utils /** * Handles registering and unregistering the application with the YARN ResourceManager. @@ -71,7 +70,8 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl) + amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port, + trackingUrl) registered = true } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index f406fabd6186..8eda6cb1277c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.CredentialUpdater import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils @@ -38,8 +37,6 @@ import org.apache.spark.util.Utils object YarnSparkHadoopUtil { - private var credentialUpdater: CredentialUpdater = _ - // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering // the common cases. Memory overhead tends to grow with container size. @@ -206,21 +203,4 @@ object YarnSparkHadoopUtil { filesystemsToAccess + stagingFS } - def startCredentialUpdater(sparkConf: SparkConf): Unit = { - val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager) - credentialUpdater.start() - } - - def stopCredentialUpdater(): Unit = { - if (credentialUpdater != null) { - credentialUpdater.stop() - credentialUpdater = null - } - } - } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 3ba3ae5ab440..1a99b3bd5767 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -231,16 +231,6 @@ package object config { /* Security configuration. 
*/ - private[spark] val CREDENTIAL_FILE_MAX_COUNT = - ConfigBuilder("spark.yarn.credentials.file.retention.count") - .intConf - .createWithDefault(5) - - private[spark] val CREDENTIALS_FILE_MAX_RETENTION = - ConfigBuilder("spark.yarn.credentials.file.retention.days") - .intConf - .createWithDefault(5) - private[spark] val NAMENODES_TO_ACCESS = ConfigBuilder("spark.yarn.access.namenodes") .doc("Extra NameNode URLs for which to request delegation tokens. The NameNode that hosts " + "fs.defaultFS does not need to be listed here.") @@ -271,11 +261,6 @@ package object config { /* Private configs. */ - private[spark] val CREDENTIALS_FILE_PATH = ConfigBuilder("spark.yarn.credentials.file") - .internal() - .stringConf - .createWithDefault(null) - // Internal config to propagate the location of the user's jar to the driver/executors private[spark] val APP_JAR = ConfigBuilder("spark.yarn.user.jar") .internal() @@ -329,16 +314,6 @@ package object config { .stringConf .createOptional - private[spark] val CREDENTIALS_RENEWAL_TIME = ConfigBuilder("spark.yarn.credentials.renewalTime") - .internal() - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(Long.MaxValue) - - private[spark] val CREDENTIALS_UPDATE_TIME = ConfigBuilder("spark.yarn.credentials.updateTime") - .internal() - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(Long.MaxValue) - private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.yarn.kerberos.relogin.period") .timeConf(TimeUnit.SECONDS) .createWithDefaultString("1m") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala index eaf2cff111a4..bc8d47dbd54c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala @@ -18,221 +18,160 @@ package org.apache.spark.deploy.yarn.security import java.security.PrivilegedExceptionAction import java.util.concurrent.{ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.AtomicReference import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.ui.UIUtils import org.apache.spark.util.ThreadUtils /** - * The following methods are primarily meant to make sure long-running apps like Spark - * Streaming apps can run without interruption while accessing secured services. The - * scheduleLoginFromKeytab method is called on the AM to get the new credentials. - * This method wakes up a thread that logs into the KDC - * once 75% of the renewal interval of the original credentials used for the container - * has elapsed. 
It then obtains new credentials and writes them to HDFS in a - * pre-specified location - the prefix of which is specified in the sparkConf by - * spark.yarn.credentials.file (so the file(s) would be named c-timestamp1-1, c-timestamp2-2 etc. - * - each update goes to a new file, with a monotonically increasing suffix), also the - * timestamp1, timestamp2 here indicates the time of next update for CredentialUpdater. - * After this, the credentials are renewed once 75% of the new tokens renewal interval has elapsed. + * A manager tasked with periodically updating delegation tokens needed by the application. * - * On the executor and driver (yarn client mode) side, the updateCredentialsIfRequired method is - * called once 80% of the validity of the original credentials has elapsed. At that time the - * executor finds the credentials file with the latest timestamp and checks if it has read those - * credentials before (by keeping track of the suffix of the last file it read). If a new file has - * appeared, it will read the credentials and update the currently running UGI with it. This - * process happens again once 80% of the validity of this has expired. + * This manager is meant to make sure long-running apps (such as Spark Streaming apps) can run + * without interruption while accessing secured services. It periodically logs in to the KDC with + * user-provided credentials, and contacts all the configured secure services to obtain delegation + * tokens to be distributed to the rest of the application. + * + * This class will manage the kerberos login, by renewing the TGT when needed. Because the UGI API + * does not expose the TTL of the TGT, a configuration controls how often to check that a relogin is + * necessary. This is done reasonably often since the check is a no-op when the relogin is not yet + * needed. The check period can be overridden in the configuration. + * + * New delegation tokens are created once 75% of the renewal interval of the original tokens has + * elapsed. The new tokens are sent to the Spark driver endpoint once it's registered with the AM. + * The driver is tasked with distributing the tokens to other processes that might need them. 
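Under the new scheme just described, the AM holds a reference to the driver endpoint and pushes freshly serialized tokens to it; the driver then re-broadcasts them (see the UpdateDelegationTokens handling added to YarnSchedulerBackend later in this patch). If the driver has not registered yet, the tokens are applied locally instead. A rough sketch of that dispatch step, with simplified types standing in for the RPC and Hadoop classes:

    import java.util.concurrent.atomic.AtomicReference

    // Sketch only: `Ref` stands in for RpcEndpointRef; the byte array is the serialized Credentials.
    class TokenPusher[Ref <: AnyRef](
        sendToDriver: (Ref, Array[Byte]) => Unit,
        applyLocally: Array[Byte] => Unit) {

      private val driverRef = new AtomicReference[Ref]()

      def setDriverRef(ref: Ref): Unit = driverRef.set(ref)

      def publish(tokens: Array[Byte]): Unit = Option(driverRef.get()) match {
        case Some(driver) => sendToDriver(driver, tokens)
        case None         => applyLocally(tokens)
      }
    }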
*/ private[yarn] class AMCredentialRenewer( sparkConf: SparkConf, - hadoopConf: Configuration, - credentialManager: YARNHadoopDelegationTokenManager) extends Logging { + hadoopConf: Configuration) extends Logging { - private var lastCredentialsFileSuffix = 0 + private val principal = sparkConf.get(PRINCIPAL).get + private val keytab = sparkConf.get(KEYTAB).get + private val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) - private val credentialRenewerThread: ScheduledExecutorService = + private val renewalExecutor: ScheduledExecutorService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Refresh Thread") - private val hadoopUtil = SparkHadoopUtil.get + private val driverRef = new AtomicReference[RpcEndpointRef]() - private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) - private val daysToKeepFiles = sparkConf.get(CREDENTIALS_FILE_MAX_RETENTION) - private val numFilesToKeep = sparkConf.get(CREDENTIAL_FILE_MAX_COUNT) - private val freshHadoopConf = - hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) + private val renewalTask = new Runnable() { + override def run(): Unit = { + updateTokensTask() + } + } - @volatile private var timeOfNextRenewal: Long = sparkConf.get(CREDENTIALS_RENEWAL_TIME) + def setDriverRef(ref: RpcEndpointRef): Unit = { + driverRef.set(ref) + } /** - * Schedule a login from the keytab and principal set using the --principal and --keytab - * arguments to spark-submit. This login happens only when the credentials of the current user - * are about to expire. This method reads spark.yarn.principal and spark.yarn.keytab from - * SparkConf to do the login. This method is a no-op in non-YARN mode. + * Start the token renewer. Upon start, the renewer will: * + * - log in the configured user, and set up a task to keep that user's ticket renewed + * - obtain delegation tokens from all available providers + * - schedule a periodic task to update the tokens when needed. + * + * @return The newly logged in user. */ - private[spark] def scheduleLoginFromKeytab(): Unit = { - val principal = sparkConf.get(PRINCIPAL).get - val keytab = sparkConf.get(KEYTAB).get - - /** - * Schedule re-login and creation of new credentials. If credentials have already expired, this - * method will synchronously create new ones. - */ - def scheduleRenewal(runnable: Runnable): Unit = { - // Run now! - val remainingTime = timeOfNextRenewal - System.currentTimeMillis() - if (remainingTime <= 0) { - logInfo("Credentials have expired, creating new ones now.") - runnable.run() - } else { - logInfo(s"Scheduling login from keytab in $remainingTime millis.") - credentialRenewerThread.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) + def start(): UserGroupInformation = { + val originalCreds = UserGroupInformation.getCurrentUser().getCredentials() + val ugi = doLogin() + + val tgtRenewalTask = new Runnable() { + override def run(): Unit = { + ugi.checkTGTAndReloginFromKeytab() } } + val tgtRenewalPeriod = sparkConf.get(KERBEROS_RELOGIN_PERIOD) + renewalExecutor.scheduleAtFixedRate(tgtRenewalTask, tgtRenewalPeriod, tgtRenewalPeriod, + TimeUnit.SECONDS) - // This thread periodically runs on the AM to update the credentials on HDFS. 
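Because the UGI API does not expose the TGT's remaining lifetime, the renewer below simply schedules a fixed-rate checkTGTAndReloginFromKeytab() call at spark.yarn.kerberos.relogin.period; the call is a no-op while the ticket is still fresh. A sketch of that scheduling pattern (hypothetical helper, standard java.util.concurrent APIs):

    import java.util.concurrent.{ScheduledExecutorService, TimeUnit}

    // Sketch only: run a cheap relogin check at a fixed period instead of tracking the TGT's TTL.
    def scheduleTgtCheck(
        executor: ScheduledExecutorService,
        periodSeconds: Long)(check: () => Unit): Unit = {
      val task = new Runnable { override def run(): Unit = check() }
      executor.scheduleAtFixedRate(task, periodSeconds, periodSeconds, TimeUnit.SECONDS)
    }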
- val credentialRenewerRunnable = - new Runnable { - override def run(): Unit = { - try { - writeNewCredentialsToHDFS(principal, keytab) - cleanupOldFiles() - } catch { - case e: Exception => - // Log the error and try to write new tokens back in an hour - logWarning("Failed to write out new credentials to HDFS, will try again in an " + - "hour! If this happens too often tasks will fail.", e) - credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) - return - } - scheduleRenewal(this) - } - } - // Schedule update of credentials. This handles the case of updating the credentials right now - // as well, since the renewal interval will be 0, and the thread will get scheduled - // immediately. - scheduleRenewal(credentialRenewerRunnable) + val creds = obtainTokensAndScheduleRenewal(ugi) + ugi.addCredentials(creds) + + // Transfer the original user's tokens to the new user, since that's needed to connect to + // YARN. Explicitly avoid overwriting tokens that already exist in the current user's + // credentials, since those were freshly obtained above (see SPARK-23361). + val existing = ugi.getCredentials() + existing.mergeAll(originalCreds) + ugi.addCredentials(existing) + + ugi + } + + def stop(): Unit = { + renewalExecutor.shutdown() + } + + private def scheduleRenewal(delay: Long): Unit = { + val _delay = math.max(0, delay) + logInfo(s"Scheduling login from keytab in ${UIUtils.formatDuration(delay)}.") + renewalExecutor.schedule(renewalTask, _delay, TimeUnit.MILLISECONDS) } - // Keeps only files that are newer than daysToKeepFiles days, and deletes everything else. At - // least numFilesToKeep files are kept for safety - private def cleanupOldFiles(): Unit = { - import scala.concurrent.duration._ + /** + * Periodic task to login to the KDC and create new delegation tokens. Re-schedules itself + * to fetch the next set of tokens when needed. + */ + private def updateTokensTask(): Unit = { try { - val remoteFs = FileSystem.get(freshHadoopConf) - val credentialsPath = new Path(credentialsFile) - val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles.days).toMillis - hadoopUtil.listFilesSorted( - remoteFs, credentialsPath.getParent, - credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .dropRight(numFilesToKeep) - .takeWhile(_.getModificationTime < thresholdTime) - .foreach(x => remoteFs.delete(x.getPath, true)) + val freshUGI = doLogin() + val creds = obtainTokensAndScheduleRenewal(freshUGI) + val tokens = SparkHadoopUtil.get.serialize(creds) + + val driver = driverRef.get() + if (driver != null) { + logInfo("Updating delegation tokens.") + driver.send(UpdateDelegationTokens(tokens)) + } else { + // This shouldn't really happen, since the driver should register way before tokens expire + // (or the AM should time out the application). + logWarning("Delegation tokens close to expiration but no driver has registered yet.") + SparkHadoopUtil.get.addDelegationTokens(tokens, sparkConf) + } } catch { - // Such errors are not fatal, so don't throw. Make sure they are logged though case e: Exception => - logWarning("Error while attempting to cleanup old credentials. If you are seeing many " + - "such warnings there may be an issue with your HDFS cluster.", e) + val delay = TimeUnit.SECONDS.toMillis(sparkConf.get(CREDENTIALS_RENEWAL_RETRY_WAIT)) + logWarning(s"Failed to update tokens, will try again in ${UIUtils.formatDuration(delay)}!" 
+ + " If this happens too often tasks will fail.", e) + scheduleRenewal(delay) } } - private def writeNewCredentialsToHDFS(principal: String, keytab: String): Unit = { - // Keytab is copied by YARN to the working directory of the AM, so full path is - // not needed. - - // HACK: - // HDFS will not issue new delegation tokens, if the Credentials object - // passed in already has tokens for that FS even if the tokens are expired (it really only - // checks if there are tokens for the service, and not if they are valid). So the only real - // way to get new tokens is to make sure a different Credentials object is used each time to - // get new tokens and then the new tokens are copied over the current user's Credentials. - // So: - // - we login as a different user and get the UGI - // - use that UGI to get the tokens (see doAs block below) - // - copy the tokens over to the current user's credentials (this will overwrite the tokens - // in the current user's Credentials object for this FS). - // The login to KDC happens each time new tokens are required, but this is rare enough to not - // have to worry about (like once every day or so). This makes this code clearer than having - // to login and then relogin every time (the HDFS API may not relogin since we don't use this - // UGI directly for HDFS communication. - logInfo(s"Attempting to login to KDC using principal: $principal") - val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) - logInfo("Successfully logged into KDC.") - val tempCreds = keytabLoggedInUGI.getCredentials - val credentialsPath = new Path(credentialsFile) - val dst = credentialsPath.getParent - var nearestNextRenewalTime = Long.MaxValue - keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { - // Get a copy of the credentials - override def run(): Void = { - nearestNextRenewalTime = credentialManager.obtainDelegationTokens( - freshHadoopConf, - tempCreds) - null + /** + * Obtain new delegation tokens from the available providers. Schedules a new task to fetch + * new tokens before the new set expires. + * + * @return Credentials containing the new tokens. + */ + private def obtainTokensAndScheduleRenewal(ugi: UserGroupInformation): Credentials = { + ugi.doAs(new PrivilegedExceptionAction[Credentials]() { + override def run(): Credentials = { + val creds = new Credentials() + val nextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, creds) + + val timeToWait = SparkHadoopUtil.nextCredentialRenewalTime(nextRenewal, sparkConf) - + System.currentTimeMillis() + scheduleRenewal(timeToWait) + creds } }) - - val currTime = System.currentTimeMillis() - val timeOfNextUpdate = if (nearestNextRenewalTime <= currTime) { - // If next renewal time is earlier than current time, we set next renewal time to current - // time, this will trigger next renewal immediately. Also set next update time to current - // time. There still has a gap between token renewal and update will potentially introduce - // issue. - logWarning(s"Next credential renewal time ($nearestNextRenewalTime) is earlier than " + - s"current time ($currTime), which is unexpected, please check your credential renewal " + - "related configurations in the target services.") - timeOfNextRenewal = currTime - currTime - } else { - // Next valid renewal time is about 75% of credential renewal time, and update time is - // slightly later than valid renewal time (80% of renewal time). 
- timeOfNextRenewal = - SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.75) - SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.8) - } - - // Add the temp credentials back to the original ones. - UserGroupInformation.getCurrentUser.addCredentials(tempCreds) - val remoteFs = FileSystem.get(freshHadoopConf) - // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM - // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file - // and update the lastCredentialsFileSuffix. - if (lastCredentialsFileSuffix == 0) { - hadoopUtil.listFilesSorted( - remoteFs, credentialsPath.getParent, - credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .lastOption.foreach { status => - lastCredentialsFileSuffix = hadoopUtil.getSuffixForCredentialsPath(status.getPath) - } - } - val nextSuffix = lastCredentialsFileSuffix + 1 - - val tokenPathStr = - credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + - timeOfNextUpdate.toLong.toString + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + - nextSuffix - val tokenPath = new Path(tokenPathStr) - val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - - logInfo("Writing out delegation tokens to " + tempTokenPath.toString) - val credentials = UserGroupInformation.getCurrentUser.getCredentials - credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf) - logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr") - remoteFs.rename(tempTokenPath, tokenPath) - logInfo("Delegation token file rename complete.") - lastCredentialsFileSuffix = nextSuffix } - def stop(): Unit = { - credentialRenewerThread.shutdown() + private def doLogin(): UserGroupInformation = { + logInfo(s"Attempting to login to KDC using principal: $principal") + val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) + logInfo("Successfully logged into KDC.") + ugi } + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala deleted file mode 100644 index fe173dffc22a..000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.deploy.yarn.security - -import java.util.concurrent.{Executors, TimeUnit} - -import scala.util.control.NonFatal - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.{Credentials, UserGroupInformation} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.{ThreadUtils, Utils} - -private[spark] class CredentialUpdater( - sparkConf: SparkConf, - hadoopConf: Configuration, - credentialManager: YARNHadoopDelegationTokenManager) extends Logging { - - @volatile private var lastCredentialsFileSuffix = 0 - - private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) - private val freshHadoopConf = - SparkHadoopUtil.get.getConfBypassingFSCache( - hadoopConf, new Path(credentialsFile).toUri.getScheme) - - private val credentialUpdater = - Executors.newSingleThreadScheduledExecutor( - ThreadUtils.namedThreadFactory("Credential Refresh Thread")) - - // This thread wakes up and picks up new credentials from HDFS, if any. - private val credentialUpdaterRunnable = - new Runnable { - override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired()) - } - - /** Start the credential updater task */ - def start(): Unit = { - val startTime = sparkConf.get(CREDENTIALS_UPDATE_TIME) - val remainingTime = startTime - System.currentTimeMillis() - if (remainingTime <= 0) { - credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES) - } else { - logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime ms.") - credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS) - } - } - - private def updateCredentialsIfRequired(): Unit = { - val timeToNextUpdate = try { - val credentialsFilePath = new Path(credentialsFile) - val remoteFs = FileSystem.get(freshHadoopConf) - SparkHadoopUtil.get.listFilesSorted( - remoteFs, credentialsFilePath.getParent, - credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .lastOption.map { credentialsStatus => - val suffix = SparkHadoopUtil.get.getSuffixForCredentialsPath(credentialsStatus.getPath) - if (suffix > lastCredentialsFileSuffix) { - logInfo("Reading new credentials from " + credentialsStatus.getPath) - val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsStatus.getPath) - lastCredentialsFileSuffix = suffix - UserGroupInformation.getCurrentUser.addCredentials(newCredentials) - logInfo("Credentials updated from credentials file.") - - val remainingTime = (getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) - - System.currentTimeMillis()) - if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime - } else { - // If current credential file is older than expected, sleep 1 hour and check again. 
- TimeUnit.HOURS.toMillis(1) - } - }.getOrElse { - // Wait for 1 minute to check again if there's no credential file currently - TimeUnit.MINUTES.toMillis(1) - } - } catch { - // Since the file may get deleted while we are reading it, catch the Exception and come - // back in an hour to try again - case NonFatal(e) => - logWarning("Error while trying to update credentials, will try again in 1 hour", e) - TimeUnit.HOURS.toMillis(1) - } - - logInfo(s"Scheduling credentials refresh from HDFS in $timeToNextUpdate ms.") - credentialUpdater.schedule( - credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS) - } - - private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = { - val stream = remoteFs.open(tokenPath) - try { - val newCredentials = new Credentials() - newCredentials.readTokenStorageStream(stream) - newCredentials - } finally { - stream.close() - } - } - - private def getTimeOfNextUpdateFromFileName(credentialsPath: Path): Long = { - val name = credentialsPath.getName - val index = name.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) - val slice = name.substring(0, index) - val last2index = slice.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) - name.substring(last2index + 1, index).toLong - } - - def stop(): Unit = { - credentialUpdater.shutdown() - } - -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index 163cfb4eb862..d4eeb6bbcf88 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -22,11 +22,11 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials import org.apache.spark.SparkConf import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -37,11 +37,10 @@ import org.apache.spark.util.Utils */ private[yarn] class YARNHadoopDelegationTokenManager( sparkConf: SparkConf, - hadoopConf: Configuration, - fileSystems: Configuration => Set[FileSystem]) extends Logging { + hadoopConf: Configuration) extends Logging { - private val delegationTokenManager = - new HadoopDelegationTokenManager(sparkConf, hadoopConf, fileSystems) + private val delegationTokenManager = new HadoopDelegationTokenManager(sparkConf, hadoopConf, + conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) // public for testing val credentialProviders = getCredentialProviders diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 0c6206eebe41..06e54a2eaf95 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import 
org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.deploy.yarn.{Client, ClientArguments} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkAppHandle @@ -62,12 +62,6 @@ private[spark] class YarnClientSchedulerBackend( super.start() waitForApplication() - // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver - // reads the credentials from HDFS, just like the executors and updates its own credentials - // cache. - if (conf.contains("spark.yarn.credentials.file")) { - YarnSparkHadoopUtil.startCredentialUpdater(conf) - } monitorThread = asyncMonitorApplication() monitorThread.start() } @@ -153,7 +147,6 @@ private[spark] class YarnClientSchedulerBackend( client.reportLauncherState(SparkAppHandle.State.FINISHED) super.stop() - YarnSparkHadoopUtil.stopCredentialUpdater() client.stop() logInfo("Stopped") } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index bb615c36cd97..63bea3e7a500 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -24,9 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.util.{Failure, Success} import scala.util.control.NonFatal +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.SparkContext +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -70,6 +72,7 @@ private[spark] abstract class YarnSchedulerBackend( /** Scheduler extension services. */ private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + /** * Bind to YARN. This *must* be done before calling [[start()]]. * @@ -263,8 +266,13 @@ private[spark] abstract class YarnSchedulerBackend( logWarning(s"Requesting driver to remove executor $executorId for reason $reason") driverEndpoint.send(r) } - } + case u @ UpdateDelegationTokens(tokens) => + // Add the tokens to the current user and send a message to the scheduler so that it + // notifies all registered executors of the new tokens. + SparkHadoopUtil.get.addDelegationTokens(tokens, sc.conf) + driverEndpoint.send(u) + } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala new file mode 100644 index 000000000000..695a82f3583e --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.hadoop.yarn.conf.YarnConfiguration + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class ApplicationMasterSuite extends SparkFunSuite { + + test("history url with hadoop and spark substitutions") { + val host = "rm.host.com" + val port = 18080 + val sparkConf = new SparkConf() + + sparkConf.set("spark.yarn.historyServer.address", + "http://${hadoopconf-yarn.resourcemanager.hostname}:${spark.history.ui.port}") + val yarnConf = new YarnConfiguration() + yarnConf.set("yarn.resourcemanager.hostname", host) + val appId = "application_123_1" + val attemptId = appId + "_1" + + val shsAddr = ApplicationMaster + .getHistoryServerAddress(sparkConf, yarnConf, appId, attemptId) + + assert(shsAddr === s"http://${host}:${port}/history/${appId}/${attemptId}") + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 9d5f5eb62111..7fa597167f3f 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -357,6 +357,39 @@ class ClientSuite extends SparkFunSuite with Matchers { sparkConf.get(SECONDARY_JARS) should be (Some(Seq(new File(jar2.toURI).getName))) } + private val matching = Seq( + ("files URI match test1", "file:///file1", "file:///file2"), + ("files URI match test2", "file:///c:file1", "file://c:file2"), + ("files URI match test3", "file://host/file1", "file://host/file2"), + ("wasb URI match test", "wasb://bucket1@user", "wasb://bucket1@user/"), + ("hdfs URI match test", "hdfs:/path1", "hdfs:/path1") + ) + + matching.foreach { t => + test(t._1) { + assert(Client.compareUri(new URI(t._2), new URI(t._3)), + s"No match between ${t._2} and ${t._3}") + } + } + + private val unmatching = Seq( + ("files URI unmatch test1", "file:///file1", "file://host/file2"), + ("files URI unmatch test2", "file://host/file1", "file:///file2"), + ("files URI unmatch test3", "file://host/file1", "file://host2/file2"), + ("wasb URI unmatch test1", "wasb://bucket1@user", "wasb://bucket2@user/"), + ("wasb URI unmatch test2", "wasb://bucket1@user", "wasb://bucket1@user2/"), + ("s3 URI unmatch test", "s3a://user@pass:bucket1/", "s3a://user2@pass2:bucket1/"), + ("hdfs URI unmatch test1", "hdfs://namenode1/path1", "hdfs://namenode1:8080/path2"), + ("hdfs URI unmatch test2", "hdfs://namenode1:8020/path1", "hdfs://namenode1:8080/path2") + ) + + unmatching.foreach { t => + test(t._1) { + assert(!Client.compareUri(new URI(t._2), new URI(t._3)), + s"match between ${t._2} and ${t._3}") + } + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index cb1e3c526851..525abb6f2b35 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -251,11 +251,55 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) } handler.updateResourceRequests() - handler.processCompletedContainers(statuses.toSeq) + handler.processCompletedContainers(statuses) handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (1) } + test("kill same executor multiple times") { + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (2) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + handler.getNumExecutorsRunning should be (2) + handler.getPendingAllocate.size should be (0) + + val executorToKill = handler.executorIdToContainer.keys.head + handler.killExecutor(executorToKill) + handler.getNumExecutorsRunning should be (1) + handler.killExecutor(executorToKill) + handler.killExecutor(executorToKill) + handler.killExecutor(executorToKill) + handler.getNumExecutorsRunning should be (1) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty, Set.empty) + handler.updateResourceRequests() + handler.getPendingAllocate.size should be (1) + } + + test("process same completed container multiple times") { + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (2) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + handler.getNumExecutorsRunning should be (2) + handler.getPendingAllocate.size should be (0) + + val statuses = Seq(container1, container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) + } + handler.processCompletedContainers(statuses) + handler.getNumExecutorsRunning should be (0) + + } + test("lost executor removed from backend") { val handler = createAllocator(4) handler.updateResourceRequests() @@ -272,7 +316,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) } handler.updateResourceRequests() - handler.processCompletedContainers(statuses.toSeq) + handler.processCompletedContainers(statuses) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (2) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 061f653b97b7..a129be7c06b5 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -45,8 +45,7 @@ import org.apache.spark.util.Utils /** * Integration tests for YARN; these tests use a mini Yarn cluster to run Spark-on-YARN - * 
applications, and require the Spark assembly to be built before they can be successfully - * run. + * applications. */ @ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { @@ -115,12 +114,25 @@ class YarnClusterSuite extends BaseYarnClusterSuite { )) } - test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414)") { + test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414, SPARK-23630)") { + // Create a custom hadoop config file, to make sure it's contents are propagated to the driver. + val customConf = Utils.createTempDir() + val coreSite = """ + | + | + | spark.test.key + | testvalue + | + | + |""".stripMargin + Files.write(coreSite, new File(customConf, "core-site.xml"), StandardCharsets.UTF_8) + val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(false, mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass), - appArgs = Seq("key=value", result.getAbsolutePath()), - extraConf = Map("spark.hadoop.key" -> "value")) + appArgs = Seq("key=value", "spark.test.key=testvalue", result.getAbsolutePath()), + extraConf = Map("spark.hadoop.key" -> "value"), + extraEnv = Map("SPARK_TEST_HADOOP_CONF_DIR" -> customConf.getAbsolutePath())) checkResult(finalState, result) } @@ -152,7 +164,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { } test("run Python application in yarn-cluster mode using " + - " spark.yarn.appMasterEnv to override local envvar") { + "spark.yarn.appMasterEnv to override local envvar") { testPySpark( clientMode = false, extraConf = Map( @@ -213,6 +225,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite { finalState should be (SparkAppHandle.State.FAILED) } + test("executor env overwrite AM env in client mode") { + testExecutorEnv(true) + } + + test("executor env overwrite AM env in cluster mode") { + testExecutorEnv(false) + } + private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = { val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), @@ -293,6 +313,17 @@ class YarnClusterSuite extends BaseYarnClusterSuite { checkResult(finalState, executorResult, "OVERRIDDEN") } + private def testExecutorEnv(clientMode: Boolean): Unit = { + val result = File.createTempFile("result", null, tempDir) + val finalState = runSpark(clientMode, mainClassName(ExecutorEnvTestApp.getClass), + appArgs = Seq(result.getAbsolutePath), + extraConf = Map( + "spark.yarn.appMasterEnv.TEST_ENV" -> "am_val", + "spark.executorEnv.TEST_ENV" -> "executor_val" + ) + ) + checkResult(finalState, result, "true") + } } private[spark] class SaveExecutorInfo extends SparkListener { @@ -320,13 +351,13 @@ private object YarnClusterDriverWithFailure extends Logging with Matchers { private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matchers { def main(args: Array[String]): Unit = { - if (args.length != 2) { + if (args.length < 2) { // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | - |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value] [result file] + |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value]+ [result file] """.stripMargin) // scalastyle:on println System.exit(1) @@ -336,11 +367,16 @@ private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matc .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn test using 
SparkHadoopUtil's conf")) - val kv = args(0).split("=") - val status = new File(args(1)) + val kvs = args.take(args.length - 1).map { kv => + val parsed = kv.split("=") + (parsed(0), parsed(1)) + } + val status = new File(args.last) var result = "failure" try { - SparkHadoopUtil.get.conf.get(kv(0)) should be (kv(1)) + kvs.foreach { case (k, v) => + SparkHadoopUtil.get.conf.get(k) should be (v) + } result = "success" } finally { Files.write(result, status, StandardCharsets.UTF_8) @@ -381,7 +417,9 @@ private object YarnClusterDriver extends Logging with Matchers { // Verify that the config archive is correctly placed in the classpath of all containers. val confFile = "/" + Client.SPARK_CONF_FILE - assert(getClass().getResource(confFile) != null) + if (conf.getOption(SparkLauncher.DEPLOY_MODE) == Some("cluster")) { + assert(getClass().getResource(confFile) != null) + } val configFromExecutors = sc.parallelize(1 to 4, 4) .map { _ => Option(getClass().getResource(confFile)).map(_.toString).orNull } .collect() @@ -507,3 +545,20 @@ private object SparkContextTimeoutApp { } } + +private object ExecutorEnvTestApp { + + def main(args: Array[String]): Unit = { + val status = args(0) + val sparkConf = new SparkConf() + val sc = new SparkContext(sparkConf) + val executorEnvs = sc.parallelize(Seq(1)).flatMap { _ => sys.env }.collect().toMap + val result = sparkConf.getExecutorEnv.forall { case (k, v) => + executorEnvs.get(k).contains(v) + } + + Files.write(result.toString, new File(status), StandardCharsets.UTF_8) + sc.stop() + } + +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala index 3c7cdc0f1dab..9fa749b14c98 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala @@ -22,7 +22,6 @@ import org.apache.hadoop.security.Credentials import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { private var credentialManager: YARNHadoopDelegationTokenManager = null @@ -36,11 +35,7 @@ class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers } test("Correctly loads credential providers") { - credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - + credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) credentialManager.credentialProviders.get("yarn-test") should not be (None) } } diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh deleted file mode 100755 index b3137598692d..000000000000 --- a/sbin/build-push-docker-images.sh +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env bash - -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# This script builds and pushes docker images when run from a release of Spark -# with Kubernetes support. - -declare -A path=( [spark-driver]=kubernetes/dockerfiles/driver/Dockerfile \ - [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile \ - [spark-init]=kubernetes/dockerfiles/init-container/Dockerfile ) - -function build { - docker build -t spark-base -f kubernetes/dockerfiles/spark-base/Dockerfile . - for image in "${!path[@]}"; do - docker build -t ${REPO}/$image:${TAG} -f ${path[$image]} . - done -} - - -function push { - for image in "${!path[@]}"; do - docker push ${REPO}/$image:${TAG} - done -} - -function usage { - echo "This script must be run from a runnable distribution of Apache Spark." - echo "Usage: ./sbin/build-push-docker-images.sh -r -t build" - echo " ./sbin/build-push-docker-images.sh -r -t push" - echo "for example: ./sbin/build-push-docker-images.sh -r docker.io/myrepo -t v2.3.0 push" -} - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage - exit 0 -fi - -while getopts r:t: option -do - case "${option}" - in - r) REPO=${OPTARG};; - t) TAG=${OPTARG};; - esac -done - -if [ -z "$REPO" ] || [ -z "$TAG" ]; then - usage -else - case "${@: -1}" in - build) build;; - push) push;; - *) usage;; - esac -fi diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 7bdd3fac773a..e65e3aafe5b5 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -93,7 +93,7 @@ This file is divided into 3 sections: - + @@ -229,7 +229,7 @@ This file is divided into 3 sections: extractOpt - Use Utils.jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter + Use jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter is slower. diff --git a/sql/README.md b/sql/README.md index fe1d352050c0..70cc7c637b58 100644 --- a/sql/README.md +++ b/sql/README.md @@ -6,7 +6,7 @@ This module provides support for executing relational queries expressed in eithe Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - Execution (sql/core) - A query planner / execution engine for translating Catalyst's logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. + - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allow users to run queries that include Hive UDFs, UDAFs, and UDTFs. 
- HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. Running `sql/create-docs.sh` generates SQL documentation for built-in functions under `sql/site`. diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 9e2ced30407d..7d23637e2834 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml @@ -134,7 +134,7 @@ org.scalatest scalatest-maven-plugin - -ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6fe995f650d5..f7f921ec22c3 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -73,18 +73,22 @@ statement | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS options=tablePropertyList)? - (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? locationSpec? - (COMMENT comment=STRING)? - (TBLPROPERTIES tableProps=tablePropertyList)? + ((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitionColumnNames=identifierList) | + bucketSpec | + locationSpec | + (COMMENT comment=STRING) | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? - (COMMENT comment=STRING)? - (PARTITIONED BY '(' partitionColumns=colTypeList ')')? - bucketSpec? skewSpec? - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? + ((COMMENT comment=STRING) | + (PARTITIONED BY '(' partitionColumns=colTypeList ')') | + bucketSpec | + skewSpec | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createHiveTable | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier locationSpec? #createTableLike @@ -137,7 +141,7 @@ statement (LIKE? pattern=STRING)? #showTables | SHOW TABLE EXTENDED ((FROM | IN) db=identifier)? LIKE pattern=STRING partitionSpec? #showTable - | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases + | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases | SHOW TBLPROPERTIES table=tableIdentifier ('(' key=tablePropertyKey ')')? #showTblProperties | SHOW COLUMNS (FROM | IN) tableIdentifier @@ -394,7 +398,7 @@ hintStatement ; fromClause - : FROM relation (',' relation)* lateralView* + : FROM relation (',' relation)* (pivotClause | lateralView*)? ; aggregation @@ -409,6 +413,10 @@ groupingSet | expression ; +pivotClause + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')' + ; + lateralView : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? 
; @@ -721,7 +729,7 @@ nonReserved | ADD | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER | MAP | ARRAY | STRUCT - | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER + | PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP @@ -741,7 +749,7 @@ nonReserved | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH | ASC | DESC | LIMIT | RENAME | SETS - | AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE + | AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE | DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN @@ -756,6 +764,7 @@ FROM: 'FROM'; ADD: 'ADD'; AS: 'AS'; ALL: 'ALL'; +ANY: 'ANY'; DISTINCT: 'DISTINCT'; WHERE: 'WHERE'; GROUP: 'GROUP'; @@ -801,6 +810,7 @@ RIGHT: 'RIGHT'; FULL: 'FULL'; NATURAL: 'NATURAL'; ON: 'ON'; +PIVOT: 'PIVOT'; LATERAL: 'LATERAL'; WINDOW: 'WINDOW'; OVER: 'OVER'; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d18542b188f7..d5d934bc91ca 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -27,6 +27,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -55,9 +56,19 @@ public final class UnsafeArrayData extends ArrayData { public static int calculateHeaderPortionInBytes(int numFields) { + return (int)calculateHeaderPortionInBytes((long)numFields); + } + + public static long calculateHeaderPortionInBytes(long numFields) { return 8 + ((numFields + 63)/ 64) * 8; } + public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) { + long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) + + ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize); + return size; + } + private Object baseObject; private long baseOffset; @@ -230,7 +241,8 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); + return new UTF8String(mb); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 71c086029cc5..29a1411241cf 100644 --- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -37,6 +37,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -414,7 +415,8 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); + return new UTF8String(mb); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index f37ef83ad92b..8e9c0a2e9dc8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -16,7 +16,8 @@ */ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; // scalastyle: off /** @@ -71,13 +72,13 @@ public static long hashLong(long input, long seed) { return fmix(hash); } - public long hashUnsafeWords(Object base, long offset, int length) { - return hashUnsafeWords(base, offset, length, seed); + public long hashUnsafeWordsBlock(MemoryBlock mb) { + return hashUnsafeWordsBlock(mb, seed); } - public static long hashUnsafeWords(Object base, long offset, int length, long seed) { - assert (length % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; - long hash = hashBytesByWords(base, offset, length, seed); + public static long hashUnsafeWordsBlock(MemoryBlock mb, long seed) { + assert (mb.size() % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; + long hash = hashBytesByWordsBlock(mb, seed); return fmix(hash); } @@ -85,26 +86,36 @@ public long hashUnsafeBytes(Object base, long offset, int length) { return hashUnsafeBytes(base, offset, length, seed); } - public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { + long offset = 0; + long length = mb.size(); assert (length >= 0) : "lengthInBytes cannot be negative"; - long hash = hashBytesByWords(base, offset, length, seed); + long hash = hashBytesByWordsBlock(mb, seed); long end = offset + length; offset += length & -8; if (offset + 4L <= end) { - hash ^= (Platform.getInt(base, offset) & 0xFFFFFFFFL) * PRIME64_1; + hash ^= (mb.getInt(offset) & 0xFFFFFFFFL) * PRIME64_1; hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3; offset += 4L; } while (offset < end) { - hash ^= (Platform.getByte(base, offset) & 0xFFL) * PRIME64_5; + hash ^= (mb.getByte(offset) & 0xFFL) * PRIME64_5; hash = Long.rotateLeft(hash, 11) * PRIME64_1; offset++; } return fmix(hash); } + public static long hashUTF8String(UTF8String str, long seed) { + return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); + } + + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + return 
hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed); + } + private static long fmix(long hash) { hash ^= hash >>> 33; hash *= PRIME64_2; @@ -114,30 +125,31 @@ private static long fmix(long hash) { return hash; } - private static long hashBytesByWords(Object base, long offset, int length, long seed) { - long end = offset + length; + private static long hashBytesByWordsBlock(MemoryBlock mb, long seed) { + long offset = 0; + long length = mb.size(); long hash; if (length >= 32) { - long limit = end - 32; + long limit = length - 32; long v1 = seed + PRIME64_1 + PRIME64_2; long v2 = seed + PRIME64_2; long v3 = seed; long v4 = seed - PRIME64_1; do { - v1 += Platform.getLong(base, offset) * PRIME64_2; + v1 += mb.getLong(offset) * PRIME64_2; v1 = Long.rotateLeft(v1, 31); v1 *= PRIME64_1; - v2 += Platform.getLong(base, offset + 8) * PRIME64_2; + v2 += mb.getLong(offset + 8) * PRIME64_2; v2 = Long.rotateLeft(v2, 31); v2 *= PRIME64_1; - v3 += Platform.getLong(base, offset + 16) * PRIME64_2; + v3 += mb.getLong(offset + 16) * PRIME64_2; v3 = Long.rotateLeft(v3, 31); v3 *= PRIME64_1; - v4 += Platform.getLong(base, offset + 24) * PRIME64_2; + v4 += mb.getLong(offset + 24) * PRIME64_2; v4 = Long.rotateLeft(v4, 31); v4 *= PRIME64_1; @@ -178,9 +190,9 @@ private static long hashBytesByWords(Object base, long offset, int length, long hash += length; - long limit = end - 8; + long limit = length - 8; while (offset <= limit) { - long k1 = Platform.getLong(base, offset); + long k1 = mb.getLong(offset); hash ^= Long.rotateLeft(k1 * PRIME64_2, 31) * PRIME64_1; hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; offset += 8L; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 259976118c12..537ef244b7e8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -30,25 +30,21 @@ * this class per writing program, so that the memory segment/data buffer can be reused. Note that * for each incoming record, we should call `reset` of BufferHolder instance before write the record * and reuse the data buffer. - * - * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update - * the size of the result row, after writing a record to the buffer. However, we can skip this step - * if the fields of row are all fixed-length, as the size of result row is also fixed. */ -public class BufferHolder { +final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - public byte[] buffer; - public int cursor = Platform.BYTE_ARRAY_OFFSET; + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; private final int fixedSize; - public BufferHolder(UnsafeRow row) { + BufferHolder(UnsafeRow row) { this(row, 64); } - public BufferHolder(UnsafeRow row, int initialSize) { + BufferHolder(UnsafeRow row, int initialSize) { int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) { throw new UnsupportedOperationException( @@ -64,7 +60,7 @@ public BufferHolder(UnsafeRow row, int initialSize) { /** * Grows the buffer by at least neededSize and points the row to the buffer. 
*/ - public void grow(int neededSize) { + void grow(int neededSize) { if (neededSize > ARRAY_MAX - totalSize()) { throw new UnsupportedOperationException( "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + @@ -86,11 +82,23 @@ public void grow(int neededSize) { } } - public void reset() { + byte[] getBuffer() { + return buffer; + } + + int getCursor() { + return cursor; + } + + void increaseCursor(int val) { + cursor += val; + } + + void reset() { cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } - public int totalSize() { + int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java new file mode 100644 index 000000000000..f8000d78cd1b --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated + * {@link UTF8String} at the end. + */ +public class UTF8StringBuilder { + + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + + private ByteArrayMemoryBlock buffer; + private int length = 0; + + public UTF8StringBuilder() { + // Since initial buffer size is 16 in `StringBuilder`, we set the same size here + this.buffer = new ByteArrayMemoryBlock(16); + } + + // Grows the buffer by at least `neededSize` + private void grow(int neededSize) { + if (neededSize > ARRAY_MAX - length) { + throw new UnsupportedOperationException( + "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + + "exceeds size limitation " + ARRAY_MAX); + } + final int requestedSize = length + neededSize; + if (buffer.size() < requestedSize) { + int newLength = requestedSize < ARRAY_MAX / 2 ? 
requestedSize * 2 : ARRAY_MAX; + final ByteArrayMemoryBlock tmp = new ByteArrayMemoryBlock(newLength); + MemoryBlock.copyMemory(buffer, tmp, length); + buffer = tmp; + } + } + + public void append(UTF8String value) { + grow(value.numBytes()); + value.writeToMemory(buffer.getByteArray(), length + Platform.BYTE_ARRAY_OFFSET); + length += value.numBytes(); + } + + public void append(String value) { + append(UTF8String.fromString(value)); + } + + public UTF8String build() { + return UTF8String.fromBytes(buffer.getByteArray(), 0, length); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 791e8d80e6cb..a78dd970d23e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -21,8 +21,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes; @@ -30,16 +28,14 @@ * A helper class to write data into global row buffer using `UnsafeArrayData` format, * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. */ -public class UnsafeArrayWriter { - - private BufferHolder holder; - - // The offset of the global buffer where we start to write this array. - private int startingOffset; +public final class UnsafeArrayWriter extends UnsafeWriter { // The number of elements in this array private int numElements; + // The element size in this array + private int elementSize; + private int headerInBytes; private void assertIndexIsValid(int index) { @@ -47,13 +43,17 @@ private void assertIndexIsValid(int index) { assert index < numElements : "index (" + index + ") should < " + numElements; } - public void initialize(BufferHolder holder, int numElements, int elementSize) { + public UnsafeArrayWriter(UnsafeWriter writer, int elementSize) { + super(writer.getBufferHolder()); + this.elementSize = elementSize; + } + + public void initialize(int numElements) { // We need 8 bytes to store numElements in header this.numElements = numElements; this.headerInBytes = calculateHeaderPortionInBytes(numElements); - this.holder = holder; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); // Grows the global buffer ahead for header and fixed size data. 
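For reference, the UTF8StringBuilder added above exposes only three operations: append(UTF8String), append(String) and build(). A minimal usage sketch in Scala (illustrative only, not part of the patch):

    import org.apache.spark.sql.catalyst.expressions.codegen.UTF8StringBuilder
    import org.apache.spark.unsafe.types.UTF8String

    val sb = new UTF8StringBuilder
    sb.append("Hello, ")                        // the String overload delegates to the UTF8String one
    sb.append(UTF8String.fromString("world"))
    val out: UTF8String = sb.build()            // copies the accumulated bytes out of the internal block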
int fixedPartInBytes = @@ -61,130 +61,92 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) { holder.grow(headerInBytes + fixedPartInBytes); // Write numElements and clear out null bits to header - Platform.putLong(holder.buffer, startingOffset, numElements); + Platform.putLong(getBuffer(), startingOffset, numElements); for (int i = 8; i < headerInBytes; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } // fill 0 into reminder part of 8-bytes alignment in unsafe array for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { - Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0); - } - holder.cursor += (headerInBytes + fixedPartInBytes); - } - - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); + Platform.putByte(getBuffer(), startingOffset + headerInBytes + i, (byte) 0); } + increaseCursor(headerInBytes + fixedPartInBytes); } - private long getElementOffset(int ordinal, int elementSize) { + private long getElementOffset(int ordinal) { return startingOffset + headerInBytes + ordinal * elementSize; } - public void setOffsetAndSize(int ordinal, long currentCursor, int size) { - assertIndexIsValid(ordinal); - final long relativeOffset = currentCursor - startingOffset; - final long offsetAndSize = (relativeOffset << 32) | (long)size; - - write(ordinal, offsetAndSize); - } - private void setNullBit(int ordinal) { assertIndexIsValid(ordinal); - BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); + BitSetMethods.set(getBuffer(), startingOffset + 8, ordinal); } - public void setNullBoolean(int ordinal) { + public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false); + writeByte(getElementOffset(ordinal), (byte)0); } - public void setNullByte(int ordinal) { + public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); + writeShort(getElementOffset(ordinal), (short)0); } - public void setNullShort(int ordinal) { + public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); + writeInt(getElementOffset(ordinal), 0); } - public void setNullInt(int ordinal) { + public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); + writeLong(getElementOffset(ordinal), 0); } - public void setNullLong(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); - } - - public void setNullFloat(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0); - } - - public void setNullDouble(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0); - } - - public void setNull(int ordinal) { setNullLong(ordinal); } + public 
void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value); + writeBoolean(getElementOffset(ordinal), value); } public void write(int ordinal, byte value) { assertIndexIsValid(ordinal); - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value); + writeByte(getElementOffset(ordinal), value); } public void write(int ordinal, short value) { assertIndexIsValid(ordinal); - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value); + writeShort(getElementOffset(ordinal), value); } public void write(int ordinal, int value) { assertIndexIsValid(ordinal); - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value); + writeInt(getElementOffset(ordinal), value); } public void write(int ordinal, long value) { assertIndexIsValid(ordinal); - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value); + writeLong(getElementOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } assertIndexIsValid(ordinal); - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value); + writeFloat(getElementOffset(ordinal), value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } assertIndexIsValid(ordinal); - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value); + writeDouble(getElementOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType assertIndexIsValid(ordinal); - if (input.changePrecision(precision, scale)) { + if (input != null && input.changePrecision(precision, scale)) { if (precision <= Decimal.MAX_LONG_DIGITS()) { write(ordinal, input.toUnscaledLong()); } else { @@ -198,65 +160,14 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - setOffsetAndSize(ordinal, holder.cursor, numBytes); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); // move the cursor forward with 8-bytes boundary - holder.cursor += roundedSize; + increaseCursor(roundedSize); } } else { setNull(ordinal); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - final int numBytes = input.length; - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. 
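A note on the reworked UnsafeArrayWriter above: it no longer receives a BufferHolder; it chains off a parent UnsafeWriter, the element size is fixed at construction, and initialize only takes the element count. A rough Scala sketch of writing a two-element long array into a one-column row, using the UnsafeRowWriter constructors introduced just below (illustrative only, not part of the patch):

    import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter}

    val rowWriter = new UnsafeRowWriter(1)                  // one array-typed field
    val arrayWriter = new UnsafeArrayWriter(rowWriter, 8)   // 8-byte elements, e.g. LongType
    rowWriter.reset()
    rowWriter.zeroOutNullBytes()
    val previousCursor = rowWriter.cursor()                 // remember where the array data starts
    arrayWriter.initialize(2)
    arrayWriter.write(0, 1L)
    arrayWriter.write(1, 2L)
    rowWriter.setOffsetAndSizeFromPreviousCursor(0, previousCursor)
    val row = rowWriter.getRow                              // field 0's offset/size now points at the array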
- holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, holder.cursor, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 5d9515c0725d..71c49d8ed017 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -20,10 +20,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; /** * A helper class to write data into global row buffer using `UnsafeRow` format. @@ -31,38 +28,67 @@ * It will remember the offset of row buffer which it starts to write, and move the cursor of row * buffer while writing. If new data(can be the input record if this is the outermost writer, or * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be - * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the + * changed, so we need to call `UnsafeRowWriter.resetRowWriter` before writing, to update the * `startingOffset` and clear out null bits. * * Note that if this is the outermost writer, which means we will always write from the very * beginning of the global row buffer, we don't need to update `startingOffset` and can just call * `zeroOutNullBytes` before writing new data. */ -public class UnsafeRowWriter { +public final class UnsafeRowWriter extends UnsafeWriter { + + private final UnsafeRow row; - private final BufferHolder holder; - // The offset of the global buffer where we start to write this row. 
- private int startingOffset; private final int nullBitsSize; private final int fixedSize; - public UnsafeRowWriter(BufferHolder holder, int numFields) { - this.holder = holder; + public UnsafeRowWriter(int numFields) { + this(new UnsafeRow(numFields)); + } + + public UnsafeRowWriter(int numFields, int initialBufferSize) { + this(new UnsafeRow(numFields), initialBufferSize); + } + + public UnsafeRowWriter(UnsafeWriter writer, int numFields) { + this(null, writer.getBufferHolder(), numFields); + } + + private UnsafeRowWriter(UnsafeRow row) { + this(row, new BufferHolder(row), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, int initialBufferSize) { + this(row, new BufferHolder(row, initialBufferSize), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) { + super(holder); + this.row = row; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); this.fixedSize = nullBitsSize + 8 * numFields; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); + } + + /** + * Updates total size of the UnsafeRow using the size collected by BufferHolder, and returns + * the UnsafeRow created at a constructor + */ + public UnsafeRow getRow() { + row.setTotalSize(totalSize()); + return row; } /** * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null * bits. This should be called before we write a new nested struct to the row buffer. */ - public void reset() { - this.startingOffset = holder.cursor; + public void resetRowWriter() { + this.startingOffset = cursor(); // grow the global buffer to make sure it has enough space to write fixed-length data. - holder.grow(fixedSize); - holder.cursor += fixedSize; + grow(fixedSize); + increaseCursor(fixedSize); zeroOutNullBytes(); } @@ -72,92 +98,86 @@ public void reset() { */ public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } } - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); - } - } - - public BufferHolder holder() { return holder; } - public boolean isNullAt(int ordinal) { - return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); + return BitSetMethods.isSet(getBuffer(), startingOffset, ordinal); } public void setNullAt(int ordinal) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); - Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); + BitSetMethods.set(getBuffer(), startingOffset, ordinal); + write(ordinal, 0L); } - public long getFieldOffset(int ordinal) { - return startingOffset + nullBitsSize + 8 * ordinal; + @Override + public void setNull1Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull2Bytes(int ordinal) { + setNullAt(ordinal); } - public void setOffsetAndSize(int ordinal, long size) { - setOffsetAndSize(ordinal, holder.cursor, size); + @Override + public void setNull4Bytes(int ordinal) { + setNullAt(ordinal); } - public void setOffsetAndSize(int ordinal, long currentCursor, long size) { - final long relativeOffset = currentCursor - startingOffset; - final long fieldOffset = getFieldOffset(ordinal); - final long offsetAndSize = (relativeOffset << 32) | size; + @Override + public void setNull8Bytes(int ordinal) { + setNullAt(ordinal); + } - Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); + 
public long getFieldOffset(int ordinal) { + return startingOffset + nullBitsSize + 8 * ordinal; } public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putBoolean(holder.buffer, offset, value); + writeLong(offset, 0L); + writeBoolean(offset, value); } public void write(int ordinal, byte value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putByte(holder.buffer, offset, value); + writeLong(offset, 0L); + writeByte(offset, value); } public void write(int ordinal, short value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putShort(holder.buffer, offset, value); + writeLong(offset, 0L); + writeShort(offset, value); } public void write(int ordinal, int value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putInt(holder.buffer, offset, value); + writeLong(offset, 0L); + writeInt(offset, value); } public void write(int ordinal, long value) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), value); + writeLong(getFieldOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putFloat(holder.buffer, offset, value); + writeLong(offset, 0); + writeFloat(offset, value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } - Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value); + writeDouble(getFieldOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { // make sure Decimal object has the same scale as DecimalType - if (input.changePrecision(precision, scale)) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong()); + if (input != null && input.changePrecision(precision, scale)) { + write(ordinal, input.toUnscaledLong()); } else { setNullAt(ordinal); } @@ -165,82 +185,31 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // grow the global buffer before writing data. holder.grow(16); - // zero-out the bytes - Platform.putLong(holder.buffer, holder.cursor, 0L); - Platform.putLong(holder.buffer, holder.cursor + 8, 0L); - // Make sure Decimal object has the same scale as DecimalType. // Note that we may pass in null Decimal object to set null for it. if (input == null || !input.changePrecision(precision, scale)) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); + // zero-out the bytes + Platform.putLong(getBuffer(), cursor(), 0L); + Platform.putLong(getBuffer(), cursor() + 8, 0L); + + BitSetMethods.set(getBuffer(), startingOffset, ordinal); // keep the offset for future update - setOffsetAndSize(ordinal, 0L); + setOffsetAndSize(ordinal, 0); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; + final int numBytes = bytes.length; + assert numBytes <= 16; + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, bytes.length); } // move the cursor forward. 
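With the constructors above, UnsafeRowWriter now owns both its UnsafeRow and its BufferHolder, so callers never touch the holder directly. A self-contained Scala sketch of the per-record write cycle (illustrative only, not part of the patch):

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
    import org.apache.spark.unsafe.types.UTF8String

    val writer = new UnsafeRowWriter(2)          // two fields, default 64-byte initial buffer
    writer.reset()                               // rewind the buffer before each record
    writer.zeroOutNullBytes()                    // clear the null bits for the new record
    writer.write(0, 42L)
    writer.write(1, UTF8String.fromString("spark"))
    val row: UnsafeRow = writer.getRow           // stamps the total size onto the reused row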
- holder.cursor += 16; + increaseCursor(16); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - write(ordinal, input, 0, input.length); - } - - public void write(int ordinal, byte[] input, int offset, int numBytes) { - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset, - holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java new file mode 100644 index 000000000000..278165500200 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeMapData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Base class for writing Unsafe* structures. 
+ */ +public abstract class UnsafeWriter { + // Keep internal buffer holder + protected final BufferHolder holder; + + // The offset of the global buffer where we start to write this structure. + protected int startingOffset; + + protected UnsafeWriter(BufferHolder holder) { + this.holder = holder; + } + + /** + * Accessor methods are delegated from BufferHolder class + */ + public final BufferHolder getBufferHolder() { + return holder; + } + + public final byte[] getBuffer() { + return holder.getBuffer(); + } + + public final void reset() { + holder.reset(); + } + + public final int totalSize() { + return holder.totalSize(); + } + + public final void grow(int neededSize) { + holder.grow(neededSize); + } + + public final int cursor() { + return holder.getCursor(); + } + + public final void increaseCursor(int val) { + holder.increaseCursor(val); + } + + public final void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) { + setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor); + } + + protected void setOffsetAndSize(int ordinal, int size) { + setOffsetAndSize(ordinal, cursor(), size); + } + + protected void setOffsetAndSize(int ordinal, int currentCursor, int size) { + final long relativeOffset = currentCursor - startingOffset; + final long offsetAndSize = (relativeOffset << 32) | (long)size; + + write(ordinal, offsetAndSize); + } + + protected final void zeroOutPaddingBytes(int numBytes) { + if ((numBytes & 0x07) > 0) { + Platform.putLong(getBuffer(), cursor() + ((numBytes >> 3) << 3), 0L); + } + } + + public abstract void setNull1Bytes(int ordinal); + public abstract void setNull2Bytes(int ordinal); + public abstract void setNull4Bytes(int ordinal); + public abstract void setNull8Bytes(int ordinal); + + public abstract void write(int ordinal, boolean value); + public abstract void write(int ordinal, byte value); + public abstract void write(int ordinal, short value); + public abstract void write(int ordinal, int value); + public abstract void write(int ordinal, long value); + public abstract void write(int ordinal, float value); + public abstract void write(int ordinal, double value); + public abstract void write(int ordinal, Decimal input, int precision, int scale); + + public final void write(int ordinal, UTF8String input) { + writeUnalignedBytes(ordinal, input.getBaseObject(), input.getBaseOffset(), input.numBytes()); + } + + public final void write(int ordinal, byte[] input) { + write(ordinal, input, 0, input.length); + } + + public final void write(int ordinal, byte[] input, int offset, int numBytes) { + writeUnalignedBytes(ordinal, input, Platform.BYTE_ARRAY_OFFSET + offset, numBytes); + } + + private void writeUnalignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + grow(roundedSize); + zeroOutPaddingBytes(numBytes); + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); + increaseCursor(roundedSize); + } + + public final void write(int ordinal, CalendarInterval input) { + // grow the global buffer before writing data. + grow(16); + + // Write the months and microseconds fields of Interval to the variable length portion. + Platform.putLong(getBuffer(), cursor(), input.months); + Platform.putLong(getBuffer(), cursor() + 8, input.microseconds); + + setOffsetAndSize(ordinal, 16); + + // move the cursor forward. 
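The setOffsetAndSize helpers above pack a variable-length field's relative offset and byte length into one 8-byte word; UnsafeRow.getUTF8String and UnsafeArrayData.getUTF8String earlier in this patch unpack the same encoding. A small worked example (illustrative only):

    val relativeOffset = 24L                                 // bytes from the start of the enclosing row or array
    val size = 5                                             // payload length in bytes
    val offsetAndSize = (relativeOffset << 32) | size.toLong
    assert((offsetAndSize >> 32).toInt == 24)                // a reader recovers the offset from the high bits
    assert(offsetAndSize.toInt == 5)                         // and the size from the low bits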
+ increaseCursor(16); + } + + public final void write(int ordinal, UnsafeRow row) { + writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes()); + } + + public final void write(int ordinal, UnsafeMapData map) { + writeAlignedBytes(ordinal, map.getBaseObject(), map.getBaseOffset(), map.getSizeInBytes()); + } + + public final void write(UnsafeArrayData array) { + // Unsafe arrays both can be written as a regular array field or as part of a map. This makes + // updating the offset and size dependent on the code path, this is why we currently do not + // provide an method for writing unsafe arrays that also updates the size and offset. + int numBytes = array.getSizeInBytes(); + grow(numBytes); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + getBuffer(), + cursor(), + numBytes); + increaseCursor(numBytes); + } + + private void writeAlignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + grow(numBytes); + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); + increaseCursor(numBytes); + } + + protected final void writeBoolean(long offset, boolean value) { + Platform.putBoolean(getBuffer(), offset, value); + } + + protected final void writeByte(long offset, byte value) { + Platform.putByte(getBuffer(), offset, value); + } + + protected final void writeShort(long offset, short value) { + Platform.putShort(getBuffer(), offset, value); + } + + protected final void writeInt(long offset, int value) { + Platform.putInt(getBuffer(), offset, value); + } + + protected final void writeLong(long offset, long value) { + Platform.putLong(getBuffer(), offset, value); + } + + protected final void writeFloat(long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } + Platform.putFloat(getBuffer(), offset, value); + } + + protected final void writeDouble(long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } + Platform.putDouble(getBuffer(), offset, value); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java new file mode 100644 index 000000000000..bb77b5bf6de2 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.collection.unsafe.sort.RecordComparator; + +public final class RecordBinaryComparator extends RecordComparator { + + // TODO(jiangxb) Add test suite for this. 
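One detail in the primitive helpers above: writeFloat and writeDouble normalize any NaN to the canonical Float.NaN / Double.NaN bit pattern, presumably so that logically equal values serialize to identical bytes for byte-wise comparators such as the RecordBinaryComparator whose compare implementation follows. A quick illustration of why that matters (illustrative only):

    // NaN has many bit patterns; without normalization two NaNs would serialize differently.
    val canonical = java.lang.Float.floatToRawIntBits(Float.NaN)                               // 0x7fc00000
    val other = java.lang.Float.floatToRawIntBits(java.lang.Float.intBitsToFloat(0x7fc00001))  // another NaN payload
    assert(canonical != other)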
+ @Override + public int compare( + Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { + int i = 0; + int res = 0; + + // If the arrays have different length, the longer one is larger. + if (leftLen != rightLen) { + return leftLen - rightLen; + } + + // The following logic uses `leftLen` as the length for both `leftObj` and `rightObj`, since + // we have guaranteed `leftLen` == `rightLen`. + + // check if stars align and we can get both offsets to be aligned + if ((leftOff % 8) == (rightOff % 8)) { + while ((leftOff + i) % 8 != 0 && i < leftLen) { + res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - + (Platform.getByte(rightObj, rightOff + i) & 0xff); + if (res != 0) return res; + i += 1; + } + } + // for architectures that support unaligned accesses, chew it up 8 bytes at a time + if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { + while (i <= leftLen - 8) { + res = (int) ((Platform.getLong(leftObj, leftOff + i) - + Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); + if (res != 0) return res; + i += 8; + } + } + // this will finish off the unaligned comparisons, or do the entire aligned comparison + // whichever is needed. + while (i < leftLen) { + res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - + (Platform.getByte(rightObj, rightOff + i) & 0xff); + if (res != 0) return res; + i += 1; + } + + // The two arrays are equal. + return 0; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 6b002f0d3f8e..1b2f5eee5ccd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution; import java.io.IOException; +import java.util.function.Supplier; import scala.collection.AbstractIterator; import scala.collection.Iterator; @@ -56,26 +57,50 @@ public abstract static class PrefixComputer { public static class Prefix { /** Key prefix value, or the null prefix value if isNull = true. **/ - long value; + public long value; /** Whether the key is null. */ - boolean isNull; + public boolean isNull; } /** * Computes prefix for the given row. For efficiency, the returned object may be reused in * further calls to a given PrefixComputer. 
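The comparator above orders records first by length and then by an unsigned byte-wise scan, using 8-byte strides where alignment (or Platform.unaligned()) allows. A minimal Scala sketch of invoking it directly on two on-heap byte arrays (illustrative only; real callers hand it record pointers from the sorter):

    import org.apache.spark.sql.execution.RecordBinaryComparator
    import org.apache.spark.unsafe.Platform

    val cmp = new RecordBinaryComparator
    val a = Array[Byte](1, 2, 3)
    val b = Array[Byte](1, 2, 4)
    val res = cmp.compare(a, Platform.BYTE_ARRAY_OFFSET, a.length,
                          b, Platform.BYTE_ARRAY_OFFSET, b.length)
    assert(res < 0)                              // equal length, so the first differing byte decides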
*/ - abstract Prefix computePrefix(InternalRow row); + public abstract Prefix computePrefix(InternalRow row); } - public UnsafeExternalRowSorter( + public static UnsafeExternalRowSorter createWithRecordComparator( + StructType schema, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + public static UnsafeExternalRowSorter create( StructType schema, Ordering ordering, PrefixComparator prefixComparator, PrefixComputer prefixComputer, long pageSizeBytes, boolean canUseRadixSort) throws IOException { + Supplier recordComparatorSupplier = + () -> new RowComparator(ordering, schema.length()); + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + private UnsafeExternalRowSorter( + StructType schema, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { this.schema = schema; this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); @@ -85,7 +110,7 @@ public UnsafeExternalRowSorter( sparkEnv.blockManager(), sparkEnv.serializerManager(), taskContext, - () -> new RowComparator(ordering, schema.length()), + recordComparatorSupplier, prefixComparator, sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -206,7 +231,13 @@ private static final class RowComparator extends RecordComparator { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + public int compare( + Object baseObj1, + long baseOff1, + int baseLen1, + Object baseObj2, + long baseOff2, + int baseLen2) { // Note that since ordering doesn't need the total length of the record, we just pass 0 // into the row. row1.pointTo(baseObj1, baseOff1, 0); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 2800b3068f87..470c128ee6c3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; /** - * OutputMode is used to what data will be written to a streaming sink when there is + * OutputMode describes what data will be written to a streaming sink when there is * new data available in a streaming DataFrame/Dataset. 
* * @since 2.0.0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 29110640d64f..274d75e680f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} -import org.apache.spark.sql.types.{DataType, Decimal, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -119,4 +119,28 @@ object InternalRow { case v: MapData => v.copy() case _ => value } + + /** + * Returns an accessor for an `InternalRow` with given data type. The returned accessor + * actually takes a `SpecializedGetters` input because it can be generalized to other classes + * that implements `SpecializedGetters` (e.g., `ArrayData`) too. + */ + def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match { + case BooleanType => (input, ordinal) => input.getBoolean(ordinal) + case ByteType => (input, ordinal) => input.getByte(ordinal) + case ShortType => (input, ordinal) => input.getShort(ordinal) + case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) + case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) + case FloatType => (input, ordinal) => input.getFloat(ordinal) + case DoubleType => (input, ordinal) => input.getDouble(ordinal) + case StringType => (input, ordinal) => input.getUTF8String(ordinal) + case BinaryType => (input, ordinal) => input.getBinary(ordinal) + case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) + case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) + case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) + case _: ArrayType => (input, ordinal) => input.getArray(ordinal) + case _: MapType => (input, ordinal) => input.getMap(ordinal) + case u: UserDefinedType[_] => getAccessor(u.sqlType) + case _ => (input, ordinal) => input.get(ordinal, dataType) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 65040f1af4b0..f9acc208b715 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql.catalyst +import java.lang.reflect.Constructor + +import org.apache.commons.lang3.reflect.ConstructorUtils + import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -63,6 +67,7 @@ object ScalaReflection extends ScalaReflection { private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects { tpe.dealias match { + case t if t <:< definitions.NullTpe => NullType case t if t <:< 
definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType case t if t <:< definitions.DoubleTpe => DoubleType @@ -381,22 +386,22 @@ object ScalaReflection extends ScalaReflection { val clsName = getClassNameFromType(fieldType) val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. - if (cls.getName startsWith "scala.Tuple") { + val constructor = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, Some(addToPathOrdinal(i, dataType, newTypePath)), newTypePath) } else { - val constructor = deserializerFor( + deserializerFor( fieldType, Some(addToPath(fieldName, dataType, newTypePath)), newTypePath) + } - if (!nullable) { - AssertNotNull(constructor, newTypePath) - } else { - constructor - } + if (!nullable) { + AssertNotNull(constructor, newTypePath) + } else { + constructor } } @@ -712,6 +717,9 @@ object ScalaReflection extends ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects { tpe.dealias match { + // this must be the first case, since all objects in scala are instances of Null, therefore + // Null type would wrongly match the first of them, which is Option as of now + case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() Schema(udt, nullable = true) @@ -777,6 +785,15 @@ object ScalaReflection extends ScalaReflection { } } + /** + * Finds an accessible constructor with compatible parameters. This is a more flexible search + * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible + * matching constructor is returned. Otherwise, it returns `None`. + */ + def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = { + Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) + } + /** * Whether the fields of the given type is defined entirely by its constructor parameters. 
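The findConstructor helper added above delegates to Commons Lang's ConstructorUtils.getMatchingAccessibleConstructor, which, unlike Class.getConstructor, accepts assignment-compatible parameter types and returns the first accessible match (or null, which the helper turns into None). A small sketch of the same lookup outside Spark, assuming commons-lang3 is on the classpath; Point and FindConstructorExample are hypothetical names used only for illustration:

import java.lang.reflect.Constructor

import org.apache.commons.lang3.reflect.ConstructorUtils

// Hypothetical target class with a (Number, Number) constructor.
class Point(val x: Number, val y: Number)

object FindConstructorExample {
  // Mirrors the helper above: a null result from Commons Lang becomes None.
  def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] =
    Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*))

  def main(args: Array[String]): Unit = {
    // Class.getConstructor would require the exact (Number, Number) signature; the
    // Commons Lang search also accepts the assignment-compatible Integer boxes.
    val found = findConstructor(
      classOf[Point], Seq(classOf[java.lang.Integer], classOf[java.lang.Integer]))
    println(found.isDefined) // expected: true
  }
}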
*/ @@ -790,6 +807,65 @@ object ScalaReflection extends ScalaReflection { "interface", "long", "native", "new", "null", "package", "private", "protected", "public", "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw", "throws", "transient", "true", "try", "void", "volatile", "while") + + val typeJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[Boolean], + ByteType -> classOf[Byte], + ShortType -> classOf[Short], + IntegerType -> classOf[Int], + LongType -> classOf[Long], + FloatType -> classOf[Float], + DoubleType -> classOf[Double], + StringType -> classOf[UTF8String], + DateType -> classOf[DateType.InternalType], + TimestampType -> classOf[TimestampType.InternalType], + BinaryType -> classOf[BinaryType.InternalType], + CalendarIntervalType -> classOf[CalendarInterval] + ) + + val typeBoxedJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[java.lang.Boolean], + ByteType -> classOf[java.lang.Byte], + ShortType -> classOf[java.lang.Short], + IntegerType -> classOf[java.lang.Integer], + LongType -> classOf[java.lang.Long], + FloatType -> classOf[java.lang.Float], + DoubleType -> classOf[java.lang.Double], + DateType -> classOf[java.lang.Integer], + TimestampType -> classOf[java.lang.Long] + ) + + def dataTypeJavaClass(dt: DataType): Class[_] = { + dt match { + case _: DecimalType => classOf[Decimal] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case ObjectType(cls) => cls + case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + } + + def javaBoxedType(dt: DataType): Class[_] = dt match { + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayType] + case _: MapType => classOf[MapType] + case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) + case ObjectType(cls) => cls + case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { + if (arguments != Nil) { + arguments.map(e => dataTypeJavaClass(e.dataType)) + } else { + Seq.empty + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index 57f7a80bedc6..6d587abd8fd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -31,7 +31,7 @@ class TableAlreadyExistsException(db: String, table: String) extends AnalysisException(s"Table or view '$table' already exists in database '$db'") class TempTableAlreadyExistsException(table: String) - extends AnalysisException(s"Temporary table '$table' already exists") + extends AnalysisException(s"Temporary view '$table' already exists") class PartitionAlreadyExistsException(db: String, table: String, spec: TablePartitionSpec) extends AnalysisException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6d294d48c0ee..dfdcdbc1eb2c 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer +import scala.util.Random import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ @@ -52,6 +53,7 @@ object SimpleAnalyzer extends Analyzer( /** * Provides a way to keep state during the analysis, this enables us to decouple the concerns * of analysis environment from the catalog. + * The state that is kept here is per-query. * * Note this is thread local. * @@ -70,6 +72,8 @@ object AnalysisContext { } def get: AnalysisContext = value.get() + def reset(): Unit = value.remove() + private def set(context: AnalysisContext): Unit = value.set(context) def withAnalysisContext[A](database: Option[String])(f: => A): A = { @@ -95,6 +99,30 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + def executeAndCheck(plan: LogicalPlan): LogicalPlan = { + val analyzed = execute(plan) + try { + checkAnalysis(analyzed) + EliminateBarriers(analyzed) + } catch { + case e: AnalysisException => + val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) + ae.setStackTrace(e.getStackTrace) + throw ae + } + } + + override def execute(plan: LogicalPlan): LogicalPlan = { + AnalysisContext.reset() + try { + executeSameContext(plan) + } finally { + AnalysisContext.reset() + } + } + + private def executeSameContext(plan: LogicalPlan): LogicalPlan = super.execute(plan) + def resolver: Resolver = conf.resolver protected val fixedPoint = FixedPoint(maxIterations) @@ -150,6 +178,7 @@ class Analyzer( TimeWindowing :: ResolveInlineTables(conf) :: ResolveTimeZone(conf) :: + ResolvedUuidExpressions :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -164,8 +193,7 @@ class Analyzer( Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, - CleanupAliases, - EliminateBarriers) + CleanupAliases) ) /** @@ -176,7 +204,7 @@ class Analyzer( case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => - resolved :+ name -> execute(substituteCTE(relation, resolved)) + resolved :+ name -> executeSameContext(substituteCTE(relation, resolved)) }) case other => other } @@ -247,9 +275,9 @@ class Analyzer( case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.copy(aggregations = assignAliases(g.aggregations)) - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) - if child.resolved && hasUnresolvedAlias(groupByExprs) => - Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) + case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) + if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => + Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) @@ -476,9 +504,20 @@ class Analyzer( object ResolvePivot extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) - | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p - 
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) + || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) + || !p.pivotColumn.resolved => p + case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + // Check all aggregate expressions. + aggregates.foreach { e => + if (!isAggregateExpression(e)) { + throw new AnalysisException( + s"Aggregate expression required for pivot, found '$e'") + } + } + // Group-by expressions coming from SQL are implicit and need to be deduced. + val groupByExprs = groupByExprsOpt.getOrElse( + (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 def outputName(value: Literal, aggregate: Expression): String = { val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) @@ -540,16 +579,20 @@ class Analyzer( // TODO: Don't construct the physical container until after analysis. case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) } - if (filteredAggregate.fastEquals(aggregate)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$aggregate'") - } Alias(filteredAggregate, outputName(value, aggregate))() } } Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) } } + + private def isAggregateExpression(expr: Expression): Boolean = { + expr match { + case Alias(e, _) => isAggregateExpression(e) + case AggregateExpression(_, _, _, _) => true + case _ => false + } + } } /** @@ -597,10 +640,10 @@ class Analyzer( if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) { view.failAnalysis(s"The depth of view ${view.desc.identifier} exceeds the maximum " + s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " + - "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + - "aroud this.") + s"avoid errors. Increase the value of ${SQLConf.MAX_NESTED_VIEW_DEPTH.key} to work " + + "around this.") } - execute(child) + executeSameContext(child) } view.copy(child = newChild) case p @ SubqueryAlias(_, view: View) => @@ -639,7 +682,7 @@ class Analyzer( // Note that if the database is not defined, it is possible we are looking up a temp view. case e: NoSuchDatabaseException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " + - s"database ${e.db} doesn't exsits.") + s"database ${e.db} doesn't exist.") } } @@ -1269,7 +1312,7 @@ class Analyzer( do { // Try to resolve the subquery plan using the regular analyzer. previous = current - current = execute(current) + current = executeSameContext(current) // Use the outer references to resolve the subquery plan if it isn't resolved yet. 
val i = plans.iterator @@ -1392,7 +1435,7 @@ class Analyzer( grouping, Alias(cond, "havingCondition")() :: Nil, child) - val resolvedOperator = execute(aggregatedCondition) + val resolvedOperator = executeSameContext(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator .asInstanceOf[Aggregate] @@ -1450,7 +1493,8 @@ class Analyzer( val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) - val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAggregate: Aggregate = + executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] @@ -1466,7 +1510,7 @@ class Analyzer( // to push down this ordering expression and can reference the original aggregate // expression instead. val needsPushDown = ArrayBuffer.empty[NamedExpression] - val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map { + val evaluatedOrderings = resolvedAliasedOrdering.zip(unresolvedSortOrders).map { case (evaluated, order) => val index = originalAggExprs.indexWhere { case Alias(child, _) => child semanticEquals evaluated.child @@ -1509,7 +1553,7 @@ class Analyzer( } /** - * Extracts [[Generator]] from the projectList of a [[Project]] operator and create [[Generate]] + * Extracts [[Generator]] from the projectList of a [[Project]] operator and creates [[Generate]] * operator under [[Project]]. * * This rule will throw [[AnalysisException]] for following cases: @@ -1967,6 +2011,20 @@ class Analyzer( } } + /** + * Set the seed for random number generation in Uuid expressions. + */ + object ResolvedUuidExpressions extends Rule[LogicalPlan] { + private lazy val random = new Random() + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case p if p.resolved => p + case p => p transformExpressionsUp { + case Uuid(None) => Uuid(Some(random.nextLong())) + } + } + } + /** * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the * null check. 
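The ResolvedUuidExpressions rule above rewrites Uuid(None) into Uuid(Some(seed)) while the plan is still being analyzed, so the seed becomes part of the analyzed plan instead of being drawn anew on every evaluation. A toy, self-contained Scala sketch of the same seed-once-at-analysis idea over a hypothetical miniature expression tree (not the real Catalyst classes):

import scala.util.Random

// Hypothetical expression tree, for illustration only.
sealed trait Expr
final case class UuidExpr(seed: Option[Long]) extends Expr
final case class Pair(left: Expr, right: Expr) extends Expr

object SeedUuids {
  private val random = new Random()

  // Bottom-up rewrite: every unseeded UuidExpr is pinned to a concrete seed exactly once.
  def apply(e: Expr): Expr = e match {
    case UuidExpr(None) => UuidExpr(Some(random.nextLong()))
    case Pair(l, r)     => Pair(apply(l), apply(r))
    case other          => other
  }
}

A second application of the rewrite leaves already-seeded nodes untouched, much as the analyzer rule above skips plans that are already resolved.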
When user defines a UDF with primitive parameters, there is no way to tell if the @@ -2011,7 +2069,11 @@ class Analyzer( WindowExpression(wf, s.copy(frameSpecification = wf.frame)) case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) if e.resolved => - val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) + val frame = if (o.nonEmpty) { + SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + } else { + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) + } we.copy(windowSpec = s.copy(frameSpecification = frame)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bbcec5627bd4..90bda2a72ad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -153,11 +153,19 @@ trait CheckAnalysis extends PredicateHelper { s"of type ${condition.dataType.simpleString} is not a boolean.") case Aggregate(groupingExprs, aggregateExprs, child) => + def isAggregateExpression(expr: Expression) = { + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) + } + def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case aggExpr: AggregateExpression => - aggExpr.aggregateFunction.children.foreach { child => + case expr: Expression if isAggregateExpression(expr) => + val aggFunction = expr match { + case agg: AggregateExpression => agg.aggregateFunction + case udf: PythonUDF => udf + } + aggFunction.children.foreach { child => child.foreach { - case agg: AggregateExpression => + case expr: Expression if isAggregateExpression(expr) => failAnalysis( s"It is not allowed to use an aggregate function in the argument of " + s"another aggregate function. 
Please use the inner aggregate function " + @@ -348,6 +356,7 @@ trait CheckAnalysis extends PredicateHelper { } extendedCheckRules.foreach(_(plan)) plan.foreachUp { + case AnalysisBarrier(child) if !child.resolved => checkAnalysis(child) case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") case _ => } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index a8100b9b24aa..ab63131b0757 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -42,8 +43,10 @@ import org.apache.spark.sql.types._ * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) - * sum(e1) p1 + 10 s1 - * avg(e1) p1 + 4 s1 + 4 + * + * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale + * needed are out of the range of available values, the scale is reduced up to 6, in order to + * prevent the truncation of the integer part of the decimals. * * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited * precision, do the math on unlimited-precision numbers, then introduce casts back to the @@ -56,6 +59,7 @@ import org.apache.spark.sql.types._ * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + * - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value */ // scalastyle:on object DecimalPrecision extends TypeCoercionRule { @@ -93,41 +97,76 @@ object DecimalPrecision extends TypeCoercionRule { case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultScale = max(s1, s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultScale = max(s1, s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + 
CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) + } else { + DecimalType.bounded(p1 + p2 + 1, s1 + s2) + } val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) - var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) - val diff = (intDig + decDig) - DecimalType.MAX_SCALE - if (diff > 0) { - decDig -= diff / 2 + 1 - intDig = DecimalType.MAX_SCALE - decDig + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + val intDig = p1 - s1 + s2 + val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1) + val prec = intDig + scale + DecimalType.adjustPrecisionScale(prec, scale) + } else { + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + DecimalType.bounded(intDig + decDig, decDig) } - val resultType = DecimalType.bounded(intDig + decDig, decDig) val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. 
val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), @@ -137,9 +176,6 @@ object DecimalPrecision extends TypeCoercionRule { e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => val resultType = widerDecimalType(p1, s1, p2, s2) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) - - // TODO: MaxOf, MinOf, etc might want other rules - // SUM and AVERAGE are handled by the implementations of those expressions } /** @@ -243,17 +279,35 @@ object DecimalPrecision extends TypeCoercionRule { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - (left.dataType, right.dataType) match { - case (t: IntegralType, DecimalType.Fixed(p, s)) => - b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) - case (DecimalType.Fixed(p, s), t: IntegralType) => - b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(left, Cast(right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(left, DoubleType), right)) - case _ => - b + (left, right) match { + // Promote literal integers inside a binary expression with fixed-precision decimals to + // decimals. The precision and scale are the ones strictly needed by the integer value. + // Requiring more precision than necessary may lead to a useless loss of precision. + // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2. + // If we use the default precision and scale for the integer type, 2 is considered a + // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18), + // which is out of range and therefore it will becomes DECIMAL(38, 7), leading to + // potentially loosing 11 digits of the fractional part. Using only the precision needed + // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would + // become DECIMAL(38, 16), safely having a much lower precision loss. 
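The arithmetic behind the two outcomes quoted in the comment above can be checked with a short sketch of the precision/scale adjustment applied when spark.sql.decimalOperations.allowPrecisionLoss is enabled. The constants mirror the maximum precision of 38 and the MINIMUM_ADJUSTED_SCALE of 6 referenced in the division rule earlier in this patch; the exact formula below is a plausible reading of DecimalType.adjustPrecisionScale, consistent with the DECIMAL(38, 7) and DECIMAL(38, 16) results described above rather than a verbatim copy of it:

object DecimalAdjustSketch {
  val MaxPrecision = 38       // Spark's maximum decimal precision
  val MinAdjustedScale = 6    // lower bound kept for the scale when precision overflows

  // When the ideal precision exceeds 38, keep all integral digits and give up fractional
  // digits first, but never reduce the scale below min(scale, 6).
  def adjust(precision: Int, scale: Int): (Int, Int) =
    if (precision <= MaxPrecision) (precision, scale)
    else {
      val intDigits = precision - scale
      val minScale = math.min(scale, MinAdjustedScale)
      (MaxPrecision, math.max(MaxPrecision - intDigits, minScale))
    }

  def main(args: Array[String]): Unit = {
    // DECIMAL(38,18) * 2 with the literal typed as DECIMAL(10,0): ideal (49,18) -> (38,7)
    println(adjust(38 + 10 + 1, 18))
    // Same multiplication with the literal typed as DECIMAL(1,0): ideal (40,18) -> (38,16)
    println(adjust(38 + 1 + 1, 18))
  }
}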
+ case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] + && l.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r)) + case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] + && r.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) => + b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r)) + case (l @ DecimalType.Expression(_, _), r @ IntegralType()) => + b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType)))) + case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) => + b.makeCopy(Array(l, Cast(r, DoubleType))) + case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) => + b.makeCopy(Array(Cast(l, DoubleType), r)) + case _ => b } } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5ddb39822617..87b0911e150c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -308,7 +308,6 @@ object FunctionRegistry { expression[BitLength]("bit_length"), expression[Length]("char_length"), expression[Length]("character_length"), - expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), expression[Elt]("elt"), @@ -336,7 +335,6 @@ object FunctionRegistry { expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReplace]("replace"), - expression[StringReverse]("reverse"), expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), @@ -395,6 +393,7 @@ object FunctionRegistry { expression[TruncTimestamp]("date_trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[DayOfWeek]("dayofweek"), + expression[WeekDay]("weekday"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), @@ -402,12 +401,23 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayJoin]("array_join"), + expression[ArrayPosition]("array_position"), + expression[ArraySort]("array_sort"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), + expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), + expression[Slice]("slice"), + expression[Size]("cardinality"), expression[SortArray]("sort_array"), + expression[ArrayMin]("array_min"), + expression[ArrayMax]("array_max"), + expression[Reverse]("reverse"), + expression[Concat]("concat"), + expression[Flatten]("flatten"), CreateStruct.registryEntry, // misc functions @@ -526,7 +536,17 @@ object FunctionRegistry { // Otherwise, find a constructor method that matches the number of arguments, and use that. 
val params = Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - throw new AnalysisException(s"Invalid number of arguments for function $name") + val validParametersCount = constructors + .filter(_.getParameterTypes.forall(_ == classOf[Expression])) + .map(_.getParameterCount).distinct.sorted + val expectedNumberOfParameters = if (validParametersCount.length == 1) { + validParametersCount.head.toString + } else { + validParametersCount.init.mkString("one of ", ", ", " and ") + + validParametersCount.last + } + throw new AnalysisException(s"Invalid number of arguments for function $name. " + + s"Expected: $expectedNumberOfParameters; Found: ${params.length}") } Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { case Success(e) => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e9436367c7e2..b2817b0538a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -47,18 +47,19 @@ import org.apache.spark.sql.types._ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = - InConversion :: + InConversion(conf) :: WidenSetOperationTypes :: - PromoteStrings :: + PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: StackCoercion :: Division :: - ImplicitTypeCasts :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -111,6 +112,14 @@ object TypeCoercion { StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) })) + case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => + findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) + + case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => + val keyType = findTightestCommonType(kt1, kt2) + val valueType = findTightestCommonType(vt1, vt2) + Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) + case _ => None } @@ -126,7 +135,8 @@ object TypeCoercion { * is a String and the other is not. It also handles when one op is a Date and the * other is a Timestamp by making the target type to be String. */ - val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { + private def findCommonTypeForBinaryComparison( + dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = (dt1, dt2) match { // We should cast all relative timestamp/date/string comparison into string comparisons // This behaves as a user would expect because timestamp strings sort lexicographically. // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true @@ -134,11 +144,17 @@ object TypeCoercion { case (DateType, StringType) => Some(StringType) case (StringType, TimestampType) => Some(StringType) case (TimestampType, StringType) => Some(StringType) - case (TimestampType, DateType) => Some(StringType) - case (DateType, TimestampType) => Some(StringType) case (StringType, NullType) => Some(StringType) case (NullType, StringType) => Some(StringType) + // Cast to TimestampType when we compare DateType with TimestampType + // if conf.compareDateTimestampInTimestamp is true + // i.e. 
TimeStamp('2017-03-01 00:00:00') eq Date('2017-03-01') = true + case (TimestampType, DateType) + => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) + case (DateType, TimestampType) + => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) + // There is no proper decimal type we can pick, // using double type is the best we can do. // See SPARK-22469 for details. @@ -146,7 +162,7 @@ object TypeCoercion { case (s: StringType, n: DecimalType) => Some(DoubleType) case (l: StringType, r: AtomicType) if r != StringType => Some(r) - case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) + case (l: AtomicType, r: StringType) if l != StringType => Some(l) case (l, r) => None } @@ -167,11 +183,27 @@ object TypeCoercion { }) } + /** + * Whether the data type contains StringType. + */ + def hasStringType(dt: DataType): Boolean = dt match { + case StringType => true + case ArrayType(et, _) => hasStringType(et) + // Add StructType if we support string promotion for struct fields in the future. + case _ => false + } + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c) - case None => None - }) + // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal + // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. + // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, + // (TimestampType, IntegerType, StringType) should have StringType as the wider common type. + val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) + (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => + r match { + case Some(d) => findWiderTypeForTwo(d, c) + case _ => None + }) } /** @@ -312,7 +344,7 @@ object TypeCoercion { /** * Promotes strings that appear in arithmetic expressions. */ - object PromoteStrings extends TypeCoercionRule { + case class PromoteStrings(conf: SQLConf) extends TypeCoercionRule { private def castExpr(expr: Expression, targetType: DataType): Expression = { (expr.dataType, targetType) match { case (NullType, dt) => Literal.create(null, targetType) @@ -341,8 +373,8 @@ object TypeCoercion { p.makeCopy(Array(left, Cast(right, TimestampType))) case p @ BinaryComparison(left, right) - if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined => - val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get + if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined => + val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) @@ -373,7 +405,7 @@ object TypeCoercion { * operator type is found the original expression will be returned and an * Analysis Exception will be raised at the type checking phase. */ - object InConversion extends TypeCoercionRule { + case class InConversion(conf: SQLConf) extends TypeCoercionRule { private def flattenExpr(expr: Expression): Seq[Expression] = { expr match { // Multi columns in IN clause is represented as a CreateNamedStruct. 
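The findWiderCommonType change above works around pairwise widening not being associative once strings are involved: a plain left fold over (TimestampType, IntegerType, StringType) fails as soon as it tries to widen Timestamp with Integer, whereas handling the string-typed members first lets the whole list widen to StringType. A toy Scala sketch of that reordering idea, with a hypothetical two-way widening function standing in for findWiderTypeForTwo:

object WiderTypeSketch {
  // Hypothetical two-way widening used only for illustration.
  def widen(a: String, b: String): Option[String] = (a, b) match {
    case (x, y) if x == y                  => Some(x)
    case ("string", _) | (_, "string")     => Some("string")
    case ("int", "long") | ("long", "int") => Some("long")
    case _                                 => None
  }

  // Pull the string-typed members to the front so the left fold never gets stuck on a
  // pair (such as timestamp vs int) that has no common type on its own.
  def widerCommonType(types: Seq[String]): Option[String] = {
    val (stringTypes, rest) = types.partition(_ == "string")
    val reordered = stringTypes.distinct ++ rest
    reordered.headOption.flatMap { head =>
      reordered.tail.foldLeft(Option(head)) {
        case (Some(acc), t) => widen(acc, t)
        case (None, _)      => None   // an earlier failure is final
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // A naive fold over (timestamp, int, string) fails at widen(timestamp, int);
    // with the strings moved first the whole list widens to string.
    println(widerCommonType(Seq("timestamp", "int", "string"))) // Some(string)
  }
}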
@@ -399,7 +431,7 @@ object TypeCoercion { val rhs = sub.output val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => - findCommonTypeForBinaryComparison(l.dataType, r.dataType) + findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) .orElse(findTightestCommonType(l.dataType, r.dataType)) } @@ -496,6 +528,14 @@ object TypeCoercion { case None => a } + case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && + !haveSameType(children) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) + case None => c + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { @@ -684,6 +724,34 @@ object TypeCoercion { } } + /** + * Coerces the types of [[Elt]] children to expected ones. + * + * If `spark.sql.function.eltOutputAsString` is false and all children types are binary, + * the expected types are binary. Otherwise, the expected ones are strings. + */ + case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { + + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || + !children.tail.map(_.dataType).forall(_ == BinaryType)) { + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail + } + c.copy(children = newIndex +: newInputs) + } + } + } + /** * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub @@ -708,12 +776,33 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - object ImplicitTypeCasts extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + + private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) + override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Special rules for `from/to_utc_timestamp`. These 2 functions assume the input timestamp + // string is in a specific timezone, so the string itself should not contain timezone. + // TODO: We should move the type coercion logic to expressions instead of a central + // place to put all the rules. 
+ case e: FromUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + + case e: ToUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonType(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { @@ -730,7 +819,7 @@ object TypeCoercion { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. - implicitCast(in, expected).getOrElse(in) + ImplicitTypeCasts.implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) @@ -746,6 +835,9 @@ object TypeCoercion { } e.withNewChildren(children) } + } + + object ImplicitTypeCasts { /** * Given an expected data type, try to cast the expression and return the cast expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index b55043c27064..d3d6c636c4ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -345,7 +345,8 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | + _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index d336f801d077..71e23175168e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.parser.ParserUtils import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreeNode @@ -294,7 +294,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu } else { val from = input.inputSet.map(_.name).mkString(", ") val targetString = target.get.mkString(".") - throw new AnalysisException(s"cannot resolve '$targetString.*' give input columns '$from'") + throw new AnalysisException(s"cannot resolve 
'$targetString.*' given input columns '$from'") } } @@ -335,7 +335,7 @@ case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSens * @param names the names to be associated with each output of computing [[child]]. */ case class MultiAlias(child: Expression, names: Seq[String]) - extends UnaryExpression with NamedExpression with CodegenFallback { + extends UnaryExpression with NamedExpression with Unevaluable { override def name: String = throw new UnresolvedException(this, "name") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 45b4f013620c..1a145c24d78c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.catalog -import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} +import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchPartitionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -31,10 +30,13 @@ import org.apache.spark.util.ListenerBus * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. */ -abstract class ExternalCatalog - extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { +trait ExternalCatalog { import CatalogTypes.TablePartitionSpec + // -------------------------------------------------------------------------- + // Utils + // -------------------------------------------------------------------------- + protected def requireDbExists(db: String): Unit = { if (!databaseExists(db)) { throw new NoSuchDatabaseException(db) @@ -63,22 +65,9 @@ abstract class ExternalCatalog // Databases // -------------------------------------------------------------------------- - final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { - val db = dbDefinition.name - postToAll(CreateDatabasePreEvent(db)) - doCreateDatabase(dbDefinition, ignoreIfExists) - postToAll(CreateDatabaseEvent(db)) - } + def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - - final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { - postToAll(DropDatabasePreEvent(db)) - doDropDatabase(db, ignoreIfNotExists, cascade) - postToAll(DropDatabaseEvent(db)) - } - - protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -87,14 +76,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. 
*/ - final def alterDatabase(dbDefinition: CatalogDatabase): Unit = { - val db = dbDefinition.name - postToAll(AlterDatabasePreEvent(db)) - doAlterDatabase(dbDefinition) - postToAll(AlterDatabaseEvent(db)) - } - - protected def doAlterDatabase(dbDefinition: CatalogDatabase): Unit + def alterDatabase(dbDefinition: CatalogDatabase): Unit def getDatabase(db: String): CatalogDatabase @@ -110,41 +92,15 @@ abstract class ExternalCatalog // Tables // -------------------------------------------------------------------------- - final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - val tableDefinitionWithVersion = - tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) - postToAll(CreateTablePreEvent(db, name)) - doCreateTable(tableDefinitionWithVersion, ignoreIfExists) - postToAll(CreateTableEvent(db, name)) - } - - protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - - final def dropTable( - db: String, - table: String, - ignoreIfNotExists: Boolean, - purge: Boolean): Unit = { - postToAll(DropTablePreEvent(db, table)) - doDropTable(db, table, ignoreIfNotExists, purge) - postToAll(DropTableEvent(db, table)) - } + def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - protected def doDropTable( + def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit - final def renameTable(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameTablePreEvent(db, oldName, newName)) - doRenameTable(db, oldName, newName) - postToAll(RenameTableEvent(db, oldName, newName)) - } - - protected def doRenameTable(db: String, oldName: String, newName: String): Unit + def renameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -154,15 +110,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterTable(tableDefinition: CatalogTable): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) - doAlterTable(tableDefinition) - postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) - } - - protected def doAlterTable(tableDefinition: CatalogTable): Unit + def alterTable(tableDefinition: CatalogTable): Unit /** * Alter the data schema of a table identified by the provided database and table name. The new @@ -173,22 +121,10 @@ abstract class ExternalCatalog * @param table Name of table to alter schema for * @param newDataSchema Updated data schema to be used for the table. */ - final def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) - doAlterTableDataSchema(db, table, newDataSchema) - postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) - } - - protected def doAlterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit + def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. 
*/ - final def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) - doAlterTableStats(db, table, stats) - postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) - } - - protected def doAlterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit + def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable @@ -340,49 +276,17 @@ abstract class ExternalCatalog // Functions // -------------------------------------------------------------------------- - final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(CreateFunctionPreEvent(db, name)) - doCreateFunction(db, funcDefinition) - postToAll(CreateFunctionEvent(db, name)) - } + def createFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit + def dropFunction(db: String, funcName: String): Unit - final def dropFunction(db: String, funcName: String): Unit = { - postToAll(DropFunctionPreEvent(db, funcName)) - doDropFunction(db, funcName) - postToAll(DropFunctionEvent(db, funcName)) - } + def alterFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doDropFunction(db: String, funcName: String): Unit - - final def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(AlterFunctionPreEvent(db, name)) - doAlterFunction(db, funcDefinition) - postToAll(AlterFunctionEvent(db, name)) - } - - protected def doAlterFunction(db: String, funcDefinition: CatalogFunction): Unit - - final def renameFunction(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameFunctionPreEvent(db, oldName, newName)) - doRenameFunction(db, oldName, newName) - postToAll(RenameFunctionEvent(db, oldName, newName)) - } - - protected def doRenameFunction(db: String, oldName: String, newName: String): Unit + def renameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction def functionExists(db: String, funcName: String): Boolean def listFunctions(db: String, pattern: String): Seq[String] - - override protected def doPostEvent( - listener: ExternalCatalogEventListener, - event: ExternalCatalogEvent): Unit = { - listener.onEvent(event) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala new file mode 100644 index 000000000000..2f009be5816f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus + +/** + * Wraps an ExternalCatalog to provide listener events. + */ +class ExternalCatalogWithListener(delegate: ExternalCatalog) + extends ExternalCatalog + with ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { + import CatalogTypes.TablePartitionSpec + + def unwrapped: ExternalCatalog = delegate + + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + delegate.createDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + override def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + delegate.dropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } + + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = { + val db = dbDefinition.name + postToAll(AlterDatabasePreEvent(db)) + delegate.alterDatabase(dbDefinition) + postToAll(AlterDatabaseEvent(db)) + } + + override def getDatabase(db: String): CatalogDatabase = { + delegate.getDatabase(db) + } + + override def databaseExists(db: String): Boolean = { + delegate.databaseExists(db) + } + + override def listDatabases(): Seq[String] = { + delegate.listDatabases() + } + + override def listDatabases(pattern: String): Seq[String] = { + delegate.listDatabases(pattern) + } + + override def setCurrentDatabase(db: String): Unit = { + delegate.setCurrentDatabase(db) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + val tableDefinitionWithVersion = + tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) + postToAll(CreateTablePreEvent(db, name)) + delegate.createTable(tableDefinitionWithVersion, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } + + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + delegate.dropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + delegate.renameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, 
newName)) + } + + override def alterTable(tableDefinition: CatalogTable): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) + delegate.alterTable(tableDefinition) + postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) + } + + override def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) + delegate.alterTableDataSchema(db, table, newDataSchema) + postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) + } + + override def alterTableStats( + db: String, + table: String, + stats: Option[CatalogStatistics]): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) + delegate.alterTableStats(db, table, stats) + postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) + } + + override def getTable(db: String, table: String): CatalogTable = { + delegate.getTable(db, table) + } + + override def tableExists(db: String, table: String): Boolean = { + delegate.tableExists(db, table) + } + + override def listTables(db: String): Seq[String] = { + delegate.listTables(db) + } + + override def listTables(db: String, pattern: String): Seq[String] = { + delegate.listTables(db, pattern) + } + + override def loadTable( + db: String, + table: String, + loadPath: String, + isOverwrite: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadTable(db, table, loadPath, isOverwrite, isSrcLocal) + } + + override def loadPartition( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + isOverwrite: Boolean, + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadPartition( + db, table, loadPath, partition, isOverwrite, inheritTableSpecs, isSrcLocal) + } + + override def loadDynamicPartitions( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + replace: Boolean, + numDP: Int): Unit = { + delegate.loadDynamicPartitions(db, table, loadPath, partition, replace, numDP) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = { + delegate.createPartitions(db, table, parts, ignoreIfExists) + } + + override def dropPartitions( + db: String, + table: String, + partSpecs: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean, + purge: Boolean, + retainData: Boolean): Unit = { + delegate.dropPartitions(db, table, partSpecs, ignoreIfNotExists, purge, retainData) + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = { + delegate.renamePartitions(db, table, specs, newSpecs) + } + + override def alterPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Unit = { + delegate.alterPartitions(db, table, parts) + } + + override def getPartition( + db: String, + table: String, + spec: TablePartitionSpec): CatalogTablePartition = { + delegate.getPartition(db, table, spec) + } + + override def getPartitionOption( + db: String, + table: String, + spec: TablePartitionSpec): Option[CatalogTablePartition] = { + delegate.getPartitionOption(db, table, spec) + } + + override def listPartitionNames( + db: String, + table: String, + 
partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { + delegate.listPartitionNames(db, table, partialSpec) + } + + override def listPartitions( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { + delegate.listPartitions(db, table, partialSpec) + } + + override def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { + delegate.listPartitionsByFilter(db, table, predicates, defaultTimeZoneId) + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + delegate.createFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } + + override def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + delegate.dropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } + + override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(AlterFunctionPreEvent(db, name)) + delegate.alterFunction(db, funcDefinition) + postToAll(AlterFunctionEvent(db, name)) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + delegate.renameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } + + override def getFunction(db: String, funcName: String): CatalogFunction = { + delegate.getFunction(db, funcName) + } + + override def functionExists(db: String, funcName: String): Boolean = { + delegate.functionExists(db, funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = { + delegate.listFunctions(db, pattern) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 8eacfa058bd5..741dc46b0738 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -98,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -119,7 +119,7 @@ class InMemoryCatalog( } } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -152,7 +152,7 @@ class InMemoryCatalog( } } - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { requireDbExists(dbDefinition.name) catalog(dbDefinition.name).db = dbDefinition } @@ -180,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( 
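// Illustrative sketch (not part of this patch): observing catalog activity through the
// ExternalCatalogWithListener wrapper above. The wrapper posts a *PreEvent before delegating
// and the matching event after the delegate call returns, so a listener sees both sides.
// The listener and wiring below are hypothetical example code.
import java.net.URI
import org.apache.spark.sql.catalyst.catalog._

class LoggingCatalogListener extends ExternalCatalogEventListener {
  override def onEvent(event: ExternalCatalogEvent): Unit = event match {
    case CreateDatabasePreEvent(db) => println(s"about to create database $db")
    case CreateDatabaseEvent(db)    => println(s"created database $db")
    case _                          => // ignore other catalog events
  }
}

val wrapped = new ExternalCatalogWithListener(new InMemoryCatalog())
wrapped.addListener(new LoggingCatalogListener)   // addListener comes from ListenerBus
wrapped.createDatabase(
  CatalogDatabase("demo", description = "", locationUri = new URI("/tmp/demo"), properties = Map.empty),
  ignoreIfExists = true)
// prints the "about to create" line, then the "created" line once the delegate succeeds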
tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -221,7 +221,7 @@ class InMemoryCatalog( } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -264,7 +264,7 @@ class InMemoryCatalog( } } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = synchronized { @@ -294,7 +294,7 @@ class InMemoryCatalog( catalog(db).tables.remove(oldName) } - override def doAlterTable(tableDefinition: CatalogTable): Unit = synchronized { + override def alterTable(tableDefinition: CatalogTable): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -303,7 +303,7 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = synchronized { @@ -313,7 +313,7 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = newSchema) } - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = synchronized { @@ -564,24 +564,24 @@ class InMemoryCatalog( // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { + override def dropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override protected def doAlterFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def alterFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a12989623077..c390337c03ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -54,8 +54,8 @@ object SessionCatalog { * This class must be thread-safe. 
*/ class SessionCatalog( - val externalCatalog: ExternalCatalog, - globalTempViewManager: GlobalTempViewManager, + externalCatalogBuilder: () => ExternalCatalog, + globalTempViewManagerBuilder: () => GlobalTempViewManager, functionRegistry: FunctionRegistry, conf: SQLConf, hadoopConf: Configuration, @@ -70,8 +70,8 @@ class SessionCatalog( functionRegistry: FunctionRegistry, conf: SQLConf) { this( - externalCatalog, - new GlobalTempViewManager("global_temp"), + () => externalCatalog, + () => new GlobalTempViewManager("global_temp"), functionRegistry, conf, new Configuration(), @@ -87,6 +87,9 @@ class SessionCatalog( new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } + lazy val externalCatalog = externalCatalogBuilder() + lazy val globalTempViewManager = globalTempViewManagerBuilder() + /** List of temporary views, mapping from table name to their logical plan. */ @GuardedBy("this") protected val tempViews = new mutable.HashMap[String, LogicalPlan] @@ -283,9 +286,13 @@ class SessionCatalog( * Create a metastore table in the database specified in `tableDefinition`. * If no such database is specified, create it in the current database. */ - def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + def createTable( + tableDefinition: CatalogTable, + ignoreIfExists: Boolean, + validateLocation: Boolean = true): Unit = { val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) validateName(table) val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined @@ -295,15 +302,37 @@ class SessionCatalog( makeQualifiedPath(tableDefinition.storage.locationUri.get) tableDefinition.copy( storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)), - identifier = TableIdentifier(table, Some(db))) + identifier = tableIdentifier) } else { - tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + tableDefinition.copy(identifier = tableIdentifier) } requireDbExists(db) + if (tableExists(newTableDefinition.identifier)) { + if (!ignoreIfExists) { + throw new TableAlreadyExistsException(db = db, table = table) + } + } else if (validateLocation) { + validateTableLocation(newTableDefinition) + } externalCatalog.createTable(newTableDefinition, ignoreIfExists) } + def validateTableLocation(table: CatalogTable): Unit = { + // SPARK-19724: the default location of a managed table should be non-existent or empty. + if (table.tableType == CatalogTableType.MANAGED && + !conf.allowCreatingManagedTableUsingNonemptyLocation) { + val tableLocation = + new Path(table.storage.locationUri.getOrElse(defaultTablePath(table.identifier))) + val fs = tableLocation.getFileSystem(hadoopConf) + + if (fs.exists(tableLocation) && fs.listStatus(tableLocation).nonEmpty) { + throw new AnalysisException(s"Can not create the managed table('${table.identifier}')" + + s". The associated location('${tableLocation.toString}') already exists.") + } + } + } + /** * Alter the metadata of an existing metastore table identified by `tableDefinition`. * @@ -988,8 +1017,11 @@ class SessionCatalog( // ------------------------------------------------------- /** - * Create a metastore function in the database specified in `funcDefinition`. + * Create a function in the database specified in `funcDefinition`. * If no such database is specified, create it in the current database. 
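// Illustrative sketch (not part of this patch): the builder-plus-lazy-val pattern that
// SessionCatalog now uses above. Passing `() => ExternalCatalog` instead of a ready-made
// ExternalCatalog means the (potentially expensive) catalog is only constructed on first
// access, not when the session is wired up. The classes below are hypothetical stand-ins.
class ExpensiveCatalog {
  println("connecting to the metastore...")   // heavy work we want to defer
}

class LazyHolder(builder: () => ExpensiveCatalog) {
  lazy val catalog: ExpensiveCatalog = builder()   // builder() runs at most once, on first use
}

val holder = new LazyHolder(() => new ExpensiveCatalog)   // nothing printed yet
holder.catalog                                            // "connecting..." printed here
holder.catalog                                            // cached; the builder is not called again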
+ * + * @param ignoreIfExists: When true, ignore if the function with the specified name exists + * in the specified database. */ def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = { val db = formatDatabaseName(funcDefinition.identifier.database.getOrElse(getCurrentDatabase)) @@ -1061,7 +1093,7 @@ class SessionCatalog( } /** - * Check if the specified function exists. + * Check if the function with the specified name exists */ def functionExists(name: FunctionIdentifier): Boolean = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 95b6fbb0cd61..f3e67dc4e975 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -21,7 +21,9 @@ import java.net.URI import java.util.Date import scala.collection.mutable +import scala.util.control.NonFatal +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -30,7 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ /** @@ -361,7 +363,7 @@ object CatalogTable { case class CatalogStatistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, - colStats: Map[String, ColumnStat] = Map.empty) { + colStats: Map[String, CatalogColumnStat] = Map.empty) { /** * Convert [[CatalogStatistics]] to [[Statistics]], and match column stats to attributes based @@ -369,7 +371,8 @@ case class CatalogStatistics( */ def toPlanStats(planOutput: Seq[Attribute], cboEnabled: Boolean): Statistics = { if (cboEnabled && rowCount.isDefined) { - val attrStats = AttributeMap(planOutput.flatMap(a => colStats.get(a.name).map(a -> _))) + val attrStats = AttributeMap(planOutput + .flatMap(a => colStats.get(a.name).map(a -> _.toPlanStat(a.name, a.dataType)))) // Estimate size as number of rows * row size. val size = EstimationUtils.getOutputSize(planOutput, rowCount.get, attrStats) Statistics(sizeInBytes = size, rowCount = rowCount, attributeStats = attrStats) @@ -387,6 +390,143 @@ case class CatalogStatistics( } } +/** + * This class of statistics for a column is used in [[CatalogTable]] to interact with metastore. + */ +case class CatalogColumnStat( + distinctCount: Option[BigInt] = None, + min: Option[String] = None, + max: Option[String] = None, + nullCount: Option[BigInt] = None, + avgLen: Option[Long] = None, + maxLen: Option[Long] = None, + histogram: Option[Histogram] = None) { + + /** + * Returns a map from string to string that can be used to serialize the column stats. + * The key is the name of the column and name of the field (e.g. "colName.distinctCount"), + * and the value is the string representation for the value. + * min/max values are stored as Strings. They can be deserialized using + * [[CatalogColumnStat.fromExternalString]]. 
+ * + * As part of the protocol, the returned map always contains a key called "version". + * Any of the fields that are null (None) won't appear in the map. + */ + def toMap(colName: String): Map[String, String] = { + val map = new scala.collection.mutable.HashMap[String, String] + map.put(s"${colName}.${CatalogColumnStat.KEY_VERSION}", "1") + distinctCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_DISTINCT_COUNT}", v.toString) + } + nullCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_NULL_COUNT}", v.toString) + } + avgLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_AVG_LEN}", v.toString) } + maxLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_LEN}", v.toString) } + min.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MIN_VALUE}", v) } + max.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_VALUE}", v) } + histogram.foreach { h => + map.put(s"${colName}.${CatalogColumnStat.KEY_HISTOGRAM}", HistogramSerializer.serialize(h)) + } + map.toMap + } + + /** Convert [[CatalogColumnStat]] to [[ColumnStat]]. */ + def toPlanStat( + colName: String, + dataType: DataType): ColumnStat = + ColumnStat( + distinctCount = distinctCount, + min = min.map(CatalogColumnStat.fromExternalString(_, colName, dataType)), + max = max.map(CatalogColumnStat.fromExternalString(_, colName, dataType)), + nullCount = nullCount, + avgLen = avgLen, + maxLen = maxLen, + histogram = histogram) +} + +object CatalogColumnStat extends Logging { + + // List of string keys used to serialize CatalogColumnStat + val KEY_VERSION = "version" + private val KEY_DISTINCT_COUNT = "distinctCount" + private val KEY_MIN_VALUE = "min" + private val KEY_MAX_VALUE = "max" + private val KEY_NULL_COUNT = "nullCount" + private val KEY_AVG_LEN = "avgLen" + private val KEY_MAX_LEN = "maxLen" + private val KEY_HISTOGRAM = "histogram" + + /** + * Converts from string representation of data type to the corresponding Catalyst data type. + */ + def fromExternalString(s: String, name: String, dataType: DataType): Any = { + dataType match { + case BooleanType => s.toBoolean + case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) + case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) + case ByteType => s.toByte + case ShortType => s.toShort + case IntegerType => s.toInt + case LongType => s.toLong + case FloatType => s.toFloat + case DoubleType => s.toDouble + case _: DecimalType => Decimal(s) + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $name of data type: $dataType.") + } + } + + /** + * Converts the given value from Catalyst data type to string representation of external + * data type. + */ + def toExternalString(v: Any, colName: String, dataType: DataType): String = { + val externalValue = dataType match { + case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) + case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) + case BooleanType | _: IntegralType | FloatType | DoubleType => v + case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal + // This version of Spark does not use min/max for binary/string types so we ignore it. 
+ case _ => + throw new AnalysisException("Column statistics serialization is not supported for " + + s"column $colName of data type: $dataType.") + } + externalValue.toString + } + + + /** + * Creates a [[CatalogColumnStat]] object from the given map. + * This is used to deserialize column stats from some external storage. + * The serialization side is defined in [[CatalogColumnStat.toMap]]. + */ + def fromMap( + table: String, + colName: String, + map: Map[String, String]): Option[CatalogColumnStat] = { + + try { + Some(CatalogColumnStat( + distinctCount = map.get(s"${colName}.${KEY_DISTINCT_COUNT}").map(v => BigInt(v.toLong)), + min = map.get(s"${colName}.${KEY_MIN_VALUE}"), + max = map.get(s"${colName}.${KEY_MAX_VALUE}"), + nullCount = map.get(s"${colName}.${KEY_NULL_COUNT}").map(v => BigInt(v.toLong)), + avgLen = map.get(s"${colName}.${KEY_AVG_LEN}").map(_.toLong), + maxLen = map.get(s"${colName}.${KEY_MAX_LEN}").map(_.toLong), + histogram = map.get(s"${colName}.${KEY_HISTOGRAM}").map(HistogramSerializer.deserialize) + )) + } catch { + case NonFatal(e) => + logWarning(s"Failed to parse column statistics for column ${colName} in table $table", e) + None + } + } +} + case class CatalogTableType private(name: String) object CatalogTableType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 59cb26d5e6c3..efb2eba655e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -301,6 +302,8 @@ package object dsl { def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan) + def filter[T : Encoder](func: FilterFunction[T]): LogicalPlan = TypedFilter(func, logicalPlan) + def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan) def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 789750fd408f..3340789398f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6a17a397b3ef..4cc84b27d9eb 100644 --- 
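// Illustrative sketch (not part of this patch): round-tripping column statistics through the
// flat key/value form defined by CatalogColumnStat.toMap and fromMap above. The key names come
// from the patch; the statistic values are made up for the example.
import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat

val stat = CatalogColumnStat(
  distinctCount = Some(BigInt(100)),
  min = Some("1"),
  max = Some("100"),
  nullCount = Some(BigInt(0)),
  avgLen = Some(4L),
  maxLen = Some(4L))

val serialized = stat.toMap("id")
// expected keys: "id.version" -> "1", "id.distinctCount" -> "100", "id.min" -> "1",
// "id.max" -> "100", "id.nullCount" -> "0", "id.avgLen" -> "4", "id.maxLen" -> "4"
// (no histogram key appears because that field is None)

val restored = CatalogColumnStat.fromMap(table = "t", colName = "id", map = serialized)
// restored == Some(stat); if any value fails to parse, fromMap logs a warning and returns None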
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ /** @@ -33,28 +33,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]" + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { - if (input.isNullAt(ordinal)) { + if (nullable && input.isNullAt(ordinal)) { null } else { - dataType match { - case BooleanType => input.getBoolean(ordinal) - case ByteType => input.getByte(ordinal) - case ShortType => input.getShort(ordinal) - case IntegerType | DateType => input.getInt(ordinal) - case LongType | TimestampType => input.getLong(ordinal) - case FloatType => input.getFloat(ordinal) - case DoubleType => input.getDouble(ordinal) - case StringType => input.getUTF8String(ordinal) - case BinaryType => input.getBinary(ordinal) - case CalendarIntervalType => input.getInterval(ordinal) - case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale) - case t: StructType => input.getStruct(ordinal, t.size) - case _: ArrayType => input.getArray(ordinal) - case _: MapType => input.getMap(ordinal) - case _ => input.get(ordinal, dataType) - } + accessor(input, ordinal) } } @@ -66,16 +52,17 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.copy(code = oev.code) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") - val javaType = ctx.javaType(dataType) - val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = s""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + |$javaType ${ev.value} = ${ev.isNull} ? 
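// Illustrative sketch (not part of this patch): the idea behind the accessor field added to
// BoundReference above. The per-data-type dispatch is resolved once into a function, so eval
// no longer pattern matches on the type for every row. Simplified stand-in types, not Spark's.
sealed trait MiniType
case object IntT extends MiniType
case object StringT extends MiniType

trait MiniRow { def getInt(i: Int): Int; def getString(i: Int): String }

def getAccessor(dt: MiniType): (MiniRow, Int) => Any = dt match {
  case IntT    => (row, i) => row.getInt(i)      // chosen once per expression
  case StringT => (row, i) => row.getString(i)
}

class MiniBoundReference(ordinal: Int, dt: MiniType, nullable: Boolean) {
  private val accessor = getAccessor(dt)          // hoisted out of the per-row path
  def eval(row: MiniRow, isNullAt: Int => Boolean): Any =
    if (nullable && isNullAt(ordinal)) null else accessor(row, ordinal)   // null check only when needed
}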
+ | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") + ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 274d8813f16d..12330bfa55ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -206,6 +206,85 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone))) + case ArrayType(et, _) => + buildCast[ArrayData](_, array => { + val builder = new UTF8StringBuilder + builder.append("[") + if (array.numElements > 0) { + val toUTF8String = castToString(et) + if (!array.isNullAt(0)) { + builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < array.numElements) { + builder.append(",") + if (!array.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) + case MapType(kt, vt, _) => + buildCast[MapData](_, map => { + val builder = new UTF8StringBuilder + builder.append("[") + if (map.numElements > 0) { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + val keyToUTF8String = castToString(kt) + val valueToUTF8String = castToString(vt) + builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(0)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < map.numElements) { + builder.append(", ") + builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(i)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(i, vt)) + .asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) + case StructType(fields) => + buildCast[InternalRow](_, row => { + val builder = new UTF8StringBuilder + builder.append("[") + if (row.numFields > 0) { + val st = fields.map(_.dataType) + val toUTF8StringFuncs = st.map(castToString) + if (!row.isNullAt(0)) { + builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < row.numFields) { + builder.append(",") + if (!row.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) + case pudt: PythonUserDefinedType => castToString(pudt.sqlType) + case udt: UserDefinedType[_] => + buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString)) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -590,13 +669,133 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { s""" boolean $resultIsNull = $inputIsNull; - ${ctx.javaType(resultType)} $result 
= ${ctx.defaultValue(resultType)}; + ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } """ } + private def writeArrayToStringBuilder( + et: DataType, + array: String, + buffer: String, + ctx: CodegenContext): String = { + val elementToStringCode = castToStringCode(et, ctx) + val funcName = ctx.freshName("elementToString") + val elementToStringFunc = ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) { + | UTF8String elementStr = null; + | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + | return elementStr; + |} + """.stripMargin) + + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($array.numElements() > 0) { + | if (!$array.isNullAt(0)) { + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { + | $buffer.append(","); + | if (!$array.isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + + private def writeMapToStringBuilder( + kt: DataType, + vt: DataType, + map: String, + buffer: String, + ctx: CodegenContext): String = { + + def dataToStringFunc(func: String, dataType: DataType) = { + val funcName = ctx.freshName(func) + val dataToStringCode = castToStringCode(dataType, ctx) + ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) { + | UTF8String dataStr = null; + | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} + | return dataStr; + |} + """.stripMargin) + } + + val keyToStringFunc = dataToStringFunc("keyToString", kt) + val valueToStringFunc = dataToStringFunc("valueToString", vt) + val loopIndex = ctx.freshName("loopIndex") + val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0") + val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0") + val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex) + s""" + |$buffer.append("["); + |if ($map.numElements() > 0) { + | $buffer.append($keyToStringFunc($getMapFirstKey)); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt(0)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc($getMapFirstValue)); + | } + | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { + | $buffer.append(", "); + | $buffer.append($keyToStringFunc($getMapKeyArray)); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc($getMapValueArray)); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + + private def writeStructToStringBuilder( + st: Seq[DataType], + row: String, + buffer: String, + ctx: CodegenContext): String = { + val structToStringCode = st.zipWithIndex.map { case (ft, i) => + val fieldToStringCode = castToStringCode(ft, ctx) + val field = ctx.freshName("field") + val fieldStr = ctx.freshName("fieldStr") + s""" + |${if (i != 0) s"""$buffer.append(",");""" else ""} + |if (!$row.isNullAt($i)) { + | ${if (i != 0) s"""$buffer.append(" ");""" else ""} + | + | // Append $i field into the string buffer + | 
${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")}; + | UTF8String $fieldStr = null; + | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} + | $buffer.append($fieldStr); + |} + """.stripMargin + } + + val writeStructCode = ctx.splitExpressions( + expressions = structToStringCode, + funcName = "fieldToString", + arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil) + + s""" + |$buffer.append("["); + |$writeStructCode + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -608,6 +807,47 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val tz = ctx.addReferenceObj("timeZone", timeZone) (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" + case ArrayType(et, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeArrayElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } + case MapType(kt, vt, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeMapElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } + case StructType(fields) => + (c, evPrim, evNull) => { + val row = ctx.freshName("row") + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) + s""" + |InternalRow $row = $c; + |$bufferClass $buffer = new $bufferClass(); + |$writeStructCode + |$evPrim = $buffer.build(); + """.stripMargin + } + case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) + case udt: UserDefinedType[_] => + val udtRef = ctx.addReferenceObj("udt", udt) + (c, evPrim, evNull) => { + s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } @@ -965,8 +1205,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $values[$j] = null; } else { boolean $fromElementNull = false; - ${ctx.javaType(fromType)} $fromElementPrim = - ${ctx.getValue(c, fromType, j)}; + ${CodeGenerator.javaType(fromType)} $fromElementPrim = + ${CodeGenerator.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, toType, elementCast)} if ($toElementNull) { @@ -1022,20 +1262,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fromFieldNull = ctx.freshName("ffn") val toFieldPrim = ctx.freshName("tfp") val toFieldNull = ctx.freshName("tfn") - val fromType = ctx.javaType(from.fields(i).dataType) + val fromType = CodeGenerator.javaType(from.fields(i).dataType) s""" boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { $tmpResult.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)}; + ${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)}; 
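// Illustrative sketch (not part of this patch): what the new complex-type-to-string casts above
// are expected to render, exercised through Catalyst expressions. The literal values are made up;
// the expected strings follow the interpreted cast code shown in the patch.
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types._

val arrayToString = Cast(Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)), StringType)
// arrayToString.eval() should yield the UTF8String "[1, 2, 3]";
// a null element is skipped after its separator, e.g. Seq(1, null, 3) renders as "[1,, 3]"

val mapToString = Cast(Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)), StringType)
// expected to yield "[a -> 1, b -> 2]"; a null value leaves just "key ->" with nothing after it

val structToString =
  Cast(Literal.create((1, "x"), new StructType().add("i", IntegerType).add("s", StringType)), StringType)
// expected to yield "[1, x]"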
${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { $tmpResult.setNullAt($i); } else { - ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; + ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 456871493309..97dff6ae8829 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,7 +104,9 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode("", isNull, value)) + val eval = doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -118,16 +120,16 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { - val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { - val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") + val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { + val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = globalIsNull + eval.isNull = JavaCode.isNullGlobal(globalIsNull) s"$globalIsNull = $localIsNull;" } else { "" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val newValue = ctx.freshName("value") val funcName = ctx.freshName(nodeName) @@ -140,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = newValue + eval.value = JavaCode.variable(newValue, dataType) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -288,6 +290,7 @@ trait NonSQLExpression extends Expression { final override def sql: String = { transform { case a: Attribute => new PrettyAttribute(a) + case a: Alias => PrettyAttribute(a.sql, a.dataType) }.toString } } @@ -328,6 +331,32 @@ trait Nondeterministic extends Expression { protected def evalInternal(input: InternalRow): Any } +/** + * An expression that contains mutable state. A stateful expression is always non-deterministic + * because the results it produces during evaluation are not only dependent on the given input + * but also on its internal state. + * + * The state of the expressions is generally not exposed in the parameter list and this makes + * comparing stateful expressions problematic because similar stateful expressions (with the same + * parameter list) but with different internal state will be considered equal. This is especially + * problematic during tree transformations. In order to counter this the `fastEquals` method for + * stateful expressions only returns `true` for the same reference. + * + * A stateful expression should never be evaluated multiple times for a single row. This should + * only be a problem for interpreted execution. 
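// Illustrative sketch (not part of this patch): why a stateful expression must not be shared
// between evaluations, using a plain counter instead of a real Catalyst expression. The trait
// defined just below exposes freshCopy() so callers (e.g. the interpreted projection later in
// this patch) can give each execution its own instance.
class ToyCounter {
  private var n = 0L
  def next(): Long = { n += 1; n }
  def freshCopy(): ToyCounter = new ToyCounter   // new instance, state reset to zero
}

val template = new ToyCounter
template.next(); template.next()        // template now carries internal state (n = 2)

val perExecutionCopy = template.freshCopy()
perExecutionCopy.next()                 // returns 1: independent of the template's state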
This can be prevented by creating fresh copies + * of the stateful expression before execution, these can be made using the `freshCopy` function. + */ +trait Stateful extends Nondeterministic { + /** + * Return a fresh uninitialized copy of the stateful expression. + */ + def freshCopy(): Stateful + + /** + * Only the same reference is considered equal. + */ + override def fastEquals(other: TreeNode[_]): Boolean = this eq other +} /** * A leaf expression, i.e. one without any child expressions. @@ -411,15 +440,15 @@ abstract class UnaryExpression extends Expression { ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode""", isNull = "false") + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $resultCode""", isNull = FalseLiteral) } } } @@ -510,7 +539,7 @@ abstract class BinaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { @@ -518,8 +547,8 @@ abstract class BinaryExpression extends Expression { boolean ${ev.isNull} = false; ${leftGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode""", isNull = "false") + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $resultCode""", isNull = FalseLiteral) } } } @@ -654,7 +683,7 @@ abstract class TernaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { ev.copy(code = s""" @@ -662,8 +691,8 @@ abstract class TernaryExpression extends Expression { ${leftGen.code} ${midGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode""", isNull = "false") + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $resultCode""", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala new file mode 100644 index 000000000000..6d69d69b1c80 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.{UserDefinedType, _} +import org.apache.spark.unsafe.Platform + +/** + * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces, a consumer + * should copy the row if it is being buffered. This class is not thread safe. + * + * @param expressions that produces the resulting fields. These expressions must be bound + * to a schema. + */ +class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection { + import InterpretedUnsafeProjection._ + + /** Number of (top level) fields in the resulting row. */ + private[this] val numFields = expressions.length + + /** Array that expression results. */ + private[this] val values = new Array[Any](numFields) + + /** The row representing the expression results. */ + private[this] val intermediate = new GenericInternalRow(values) + + /* The row writer for UnsafeRow result */ + private[this] val rowWriter = new UnsafeRowWriter(numFields, numFields * 32) + + /** The writer that writes the intermediate result to the result row. */ + private[this] val writer: InternalRow => Unit = { + val baseWriter = generateStructWriter( + rowWriter, + expressions.map(e => StructField("", e.dataType, e.nullable))) + if (!expressions.exists(_.nullable)) { + // No nullable fields. The top-level null bit mask will always be zeroed out. + baseWriter + } else { + // Zero out the null bit mask before we write the row. + row => { + rowWriter.zeroOutNullBytes() + baseWriter(row) + } + } + } + + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } + + override def apply(row: InternalRow): UnsafeRow = { + // Put the expression results in the intermediate row. + var i = 0 + while (i < numFields) { + values(i) = expressions(i).eval(row) + i += 1 + } + + // Write the intermediate row to an unsafe row. + rowWriter.reset() + writer(intermediate) + rowWriter.getRow() + } +} + +/** + * Helper functions for creating an [[InterpretedUnsafeProjection]]. + */ +object InterpretedUnsafeProjection extends UnsafeProjectionCreator { + + /** + * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. + */ + override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + // We need to make sure that we do not reuse stateful expressions. + val cleanedExpressions = exprs.map(_.transform { + case s: Stateful => s.freshCopy() + }) + new InterpretedUnsafeProjection(cleanedExpressions.toArray) + } + + /** + * Generate a struct writer function. The generated function writes an [[InternalRow]] to the + * given buffer using the given [[UnsafeRowWriter]]. 
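// Illustrative sketch (not part of this patch): using the interpreted projection above. The
// input expressions must already be bound (BoundReference), and the produced UnsafeRow is
// reused across calls, so callers that buffer results should copy() them.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, InterpretedUnsafeProjection}
import org.apache.spark.sql.types.IntegerType

val projection = InterpretedUnsafeProjection.create(Seq(BoundReference(0, IntegerType, nullable = true)))

val kept = projection(InternalRow(1)).copy()   // copy because the projection reuses its buffer
val next = projection(InternalRow(2))          // `next` aliases the internal row; `kept` stays valid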
+ */ + private def generateStructWriter( + rowWriter: UnsafeRowWriter, + fields: Array[StructField]): InternalRow => Unit = { + val numFields = fields.length + + // Create field writers. + val fieldWriters = fields.map { field => + generateFieldWriter(rowWriter, field.dataType, field.nullable) + } + // Create basic writer. + row => { + var i = 0 + while (i < numFields) { + fieldWriters(i).apply(row, i) + i += 1 + } + } + } + + /** + * Generate a writer function for a struct field, array element, map key or map value. The + * generated function writes the element at an index in a [[SpecializedGetters]] object (row + * or array) to the given buffer using the given [[UnsafeWriter]]. + */ + private def generateFieldWriter( + writer: UnsafeWriter, + dt: DataType, + nullable: Boolean): (SpecializedGetters, Int) => Unit = { + + // Create the the basic writer. + val unsafeWriter: (SpecializedGetters, Int) => Unit = dt match { + case BooleanType => + (v, i) => writer.write(i, v.getBoolean(i)) + + case ByteType => + (v, i) => writer.write(i, v.getByte(i)) + + case ShortType => + (v, i) => writer.write(i, v.getShort(i)) + + case IntegerType | DateType => + (v, i) => writer.write(i, v.getInt(i)) + + case LongType | TimestampType => + (v, i) => writer.write(i, v.getLong(i)) + + case FloatType => + (v, i) => writer.write(i, v.getFloat(i)) + + case DoubleType => + (v, i) => writer.write(i, v.getDouble(i)) + + case DecimalType.Fixed(precision, scale) => + (v, i) => writer.write(i, v.getDecimal(i, precision, scale), precision, scale) + + case CalendarIntervalType => + (v, i) => writer.write(i, v.getInterval(i)) + + case BinaryType => + (v, i) => writer.write(i, v.getBinary(i)) + + case StringType => + (v, i) => writer.write(i, v.getUTF8String(i)) + + case StructType(fields) => + val numFields = fields.length + val rowWriter = new UnsafeRowWriter(writer, numFields) + val structWriter = generateStructWriter(rowWriter, fields) + (v, i) => { + v.getStruct(i, fields.length) match { + case row: UnsafeRow => + writer.write(i, row) + case row => + val previousCursor = writer.cursor() + // Nested struct. We don't know where this will start because a row can be + // variable length, so we need to update the offsets and zero out the bit mask. + rowWriter.resetRowWriter() + structWriter.apply(row) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) + } + } + + case ArrayType(elementType, containsNull) => + val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType)) + val elementWriter = generateFieldWriter( + arrayWriter, + elementType, + containsNull) + (v, i) => { + val previousCursor = writer.cursor() + writeArray(arrayWriter, elementWriter, v.getArray(i)) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) + } + + case MapType(keyType, valueType, valueContainsNull) => + val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType)) + val keyWriter = generateFieldWriter( + keyArrayWriter, + keyType, + nullable = false) + val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType)) + val valueWriter = generateFieldWriter( + valueArrayWriter, + valueType, + valueContainsNull) + (v, i) => { + v.getMap(i) match { + case map: UnsafeMapData => + writer.write(i, map) + case map => + val previousCursor = writer.cursor() + + // preserve 8 bytes to write the key array numBytes later. + valueArrayWriter.grow(8) + valueArrayWriter.increaseCursor(8) + + // Write the keys and write the numBytes of key array into the first 8 bytes. 
+ writeArray(keyArrayWriter, keyWriter, map.keyArray()) + Platform.putLong( + valueArrayWriter.getBuffer, + previousCursor, + valueArrayWriter.cursor - previousCursor - 8 + ) + + // Write the values. + writeArray(valueArrayWriter, valueWriter, map.valueArray()) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) + } + } + + case udt: UserDefinedType[_] => + generateFieldWriter(writer, udt.sqlType, nullable) + + case NullType => + (_, _) => {} + + case _ => + throw new SparkException(s"Unsupported data type $dt") + } + + // Always wrap the writer with a null safe version. + dt match { + case _: UserDefinedType[_] => + // The null wrapper depends on the sql type and not on the UDT. + unsafeWriter + case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS => + // We can't call setNullAt() for DecimalType with precision larger than 18, we call write + // directly. We can use the unwrapped writer directly. + unsafeWriter + case BooleanType | ByteType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull1Bytes(i) + } + } + case ShortType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull2Bytes(i) + } + } + case IntegerType | DateType | FloatType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull4Bytes(i) + } + } + case _ => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull8Bytes(i) + } + } + } + } + + /** + * Get the number of bytes elements of a data type will occupy in the fixed part of an + * [[UnsafeArrayData]] object. Reference types are stored as an 8 byte combination of an + * offset (upper 4 bytes) and a length (lower 4 bytes), these point to the variable length + * portion of the array object. Primitives take up to 8 bytes, depending on the size of the + * underlying data type. + */ + private def getElementSize(dataType: DataType): Int = dataType match { + case NullType | StringType | BinaryType | CalendarIntervalType | + _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8 + case _ => dataType.defaultSize + } + + /** + * Write an array to the buffer. If the array is already in serialized form (an instance of + * [[UnsafeArrayData]]) then we copy the bytes directly, otherwise we do an element-by-element + * copy. 
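// Illustrative worked example (not part of this patch) of the element sizing rule above:
// fixed-width slots hold primitives at their natural size, while variable-length types get an
// 8-byte slot packing an offset and a length that point into the variable-length section.
import org.apache.spark.sql.types._

// Following getElementSize as written in the patch:
//   IntegerType            -> IntegerType.defaultSize == 4 bytes
//   DoubleType             -> DoubleType.defaultSize  == 8 bytes
//   BooleanType            -> BooleanType.defaultSize == 1 byte
//   StringType             -> 8 bytes (offset + length reference)
//   ArrayType(IntegerType) -> 8 bytes (reference to a nested, variable-length array)
assert(IntegerType.defaultSize == 4)
assert(DoubleType.defaultSize == 8)
assert(BooleanType.defaultSize == 1)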
+ */ + private def writeArray( + arrayWriter: UnsafeArrayWriter, + elementWriter: (SpecializedGetters, Int) => Unit, + array: ArrayData): Unit = array match { + case unsafe: UnsafeArrayData => + arrayWriter.write(unsafe) + case _ => + val numElements = array.numElements() + arrayWriter.initialize(numElements) + var i = 0 + while (i < numElements) { + elementWriter.apply(array, i) + i += 1 + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 11fb579dfa88..ad1e7bdb3198 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DataType, LongType} within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. """) -case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { +case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { /** * Record ID within each partition. By being transient, count's value is reset to 0 every time @@ -65,18 +65,20 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") val partitionMaskTerm = "partitionMask" - ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm) ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; - $countTerm++;""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; + $countTerm++;""", isNull = FalseLiteral) } override def prettyName: String = "monotonically_increasing_id" override def sql: String = s"$prettyName()" + + override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 64b94f0a2c10..3cd73682188b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -108,8 +108,7 @@ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow } -object UnsafeProjection { - +trait UnsafeProjectionCreator { /** * Returns an UnsafeProjection for given StructType. 
* @@ -127,13 +126,13 @@ object UnsafeProjection { } /** - * Returns an UnsafeProjection for given sequence of Expressions (bounded). + * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { val unsafeExprs = exprs.map(_ transform { case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) - GenerateUnsafeProjection.generate(unsafeExprs) + createProjection(unsafeExprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -146,6 +145,18 @@ object UnsafeProjection { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } + /** + * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. + */ + protected def createProjection(exprs: Seq[Expression]): UnsafeProjection +} + +object UnsafeProjection extends UnsafeProjectionCreator { + + override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(exprs) + } + /** * Same as other create()'s but allowing enabling/disabling subexpression elimination. * TODO: refactor the plumbing and clean this up. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala similarity index 60% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index d3f743d9eb61..efd664dde725 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -15,12 +15,31 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.python +package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression} +import org.apache.spark.api.python.{PythonEvalType, PythonFunction} +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types.DataType +/** + * Helper functions for [[PythonUDF]] + */ +object PythonUDF { + private[this] val SCALAR_TYPES = Set( + PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_SCALAR_PANDAS_UDF + ) + + def isScalarPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) + } + + def isGroupAggPandasUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && + e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF + } +} + /** * A serialized version of a Python lambda function. 
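// Illustrative sketch (not part of this patch): the classification helpers that move into
// catalyst with PythonUDF above. They look only at the expression class and its evalType
// (SQL_BATCHED_UDF / SQL_SCALAR_PANDAS_UDF for scalar, SQL_GROUPED_AGG_PANDAS_UDF for group
// aggregation), so any other expression is rejected immediately.
import org.apache.spark.sql.catalyst.expressions.{Literal, PythonUDF}

// A plain literal is neither kind of Python UDF:
assert(!PythonUDF.isScalarPythonUDF(Literal(1)))
assert(!PythonUDF.isGroupAggPandasUDF(Literal(1)))
// For an actual PythonUDF instance (construction needs a PythonFunction and is elided here),
// the answer depends solely on its evalType.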
*/ @@ -30,12 +49,16 @@ case class PythonUDF( dataType: DataType, children: Seq[Expression], evalType: Int, - udfDeterministic: Boolean) + udfDeterministic: Boolean, + resultId: ExprId = NamedExpression.newExprId) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = s"$name(${children.mkString(", ")})" + lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)( + exprId = resultId) + override def nullable: Boolean = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 388ef42883ad..e869258469a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -49,6 +49,17 @@ case class ScalaUDF( udfDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { + // The constructor for SPARK 2.1 and 2.2 + def this( + function: AnyRef, + dataType: DataType, + children: Seq[Expression], + inputTypes: Seq[DataType], + udfName: Option[String]) = { + this( + function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true) + } + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = @@ -1007,11 +1018,12 @@ case class ScalaUDF( val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}") val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})" val resultConverter = s"$convertersTerm[${children.length}]" + val boxedType = CodeGenerator.boxedType(dataType) val callFunc = s""" - |${ctx.boxedType(dataType)} $resultTerm = null; + |$boxedType $resultTerm = null; |try { - | $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult); + | $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult); |} catch (Exception e) { | throw new org.apache.spark.SparkException($errorMsgTerm, e); |} @@ -1024,7 +1036,7 @@ case class ScalaUDF( |$callFunc | |boolean ${ev.isNull} = $resultTerm == null; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${ev.isNull}) { | ${ev.value} = $resultTerm; |} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index a160b9b27529..787bcaf5e81d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: 
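// Illustrative sketch (not part of this patch): the parameter list the new compatibility
// constructor above forwards to. The five-argument constructor exists so code compiled against
// the Spark 2.1/2.2 signature keeps working; it fills in nullable = true and
// udfDeterministic = true. Parameter names below follow the delegation shown in the patch;
// defaults in the primary constructor may differ.
import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
import org.apache.spark.sql.types.IntegerType

val plusOne = (x: Int) => x + 1
val udf = ScalaUDF(plusOne, IntegerType, Seq(Literal(1)), Seq(IntegerType), Some("plus_one"),
  nullable = true, udfDeterministic = true)
assert(udf.deterministic)   // udfDeterministic && all children (a single Literal) are deterministic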
CodegenContext, ev: ExprCode): ExprCode = { val idTerm = "partitionId" - ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 9a9f579b37f5..6c4a3601c173 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -165,7 +165,7 @@ case class PreciseTimestampConversion( val eval = child.genCode(ctx) ev.copy(code = eval.code + s"""boolean ${ev.isNull} = ${eval.isNull}; - |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } override def nullSafeEval(input: Any): Any = input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 149ac265e6ed..f1bbbdabb41f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -206,27 +206,15 @@ object ApproximatePercentile { * with limited memory. PercentileDigest is backed by [[QuantileSummaries]]. * * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. - * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the - * underlying quantileSummaries is compressed. */ - class PercentileDigest( - private var summaries: QuantileSummaries, - private var isCompressed: Boolean) { - - // Trigger compression if the QuantileSummaries's buffer length exceeds - // compressThresHoldBufferLength. The buffer length can be get by - // quantileSummaries.sampled.length - private[this] final val compressThresHoldBufferLength: Int = { - // Max buffer length after compression. 
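// Sketch (not part of the patch): after this hunk, PercentileDigest no longer tracks its own
// compression flag or buffer threshold; that bookkeeping lives in QuantileSummaries, which is
// constructed with `compressed = true` and queried through `summaries.compressed`. A rough
// usage sketch of the surface shown here, assuming the class remains reachable from
// ApproximatePercentile:
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest

val digest = new PercentileDigest(0.01)            // relativeError; compress threshold uses the default
(1 to 10000).foreach(i => digest.add(i.toDouble))  // inserts are buffered inside QuantileSummaries
val summaries = digest.quantileSummaries           // compresses lazily if still uncompressed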
- val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2 - // A safe upper bound for buffer length before compression - maxBufferLengthAfterCompression * 2 - } + class PercentileDigest(private var summaries: QuantileSummaries) { def this(relativeError: Double) = { - this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true) + this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true)) } + private[sql] def isCompressed: Boolean = summaries.compressed + /** Returns compressed object of [[QuantileSummaries]] */ def quantileSummaries: QuantileSummaries = { if (!isCompressed) compress() @@ -236,14 +224,6 @@ object ApproximatePercentile { /** Insert an observation value into the PercentileDigest data structure. */ def add(value: Double): Unit = { summaries = summaries.insert(value) - // The result of QuantileSummaries.insert is un-compressed - isCompressed = false - - // Currently, QuantileSummaries ignores the construction parameter compressThresHold, - // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here - // to make sure QuantileSummaries doesn't occupy infinite memory. - // TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold - if (summaries.sampled.length >= compressThresHoldBufferLength) compress() } /** In-place merges in another PercentileDigest. */ @@ -280,7 +260,6 @@ object ApproximatePercentile { private final def compress(): Unit = { summaries = summaries.compress() - isCompressed = true } } @@ -296,8 +275,8 @@ object ApproximatePercentile { Ints.BYTES + Doubles.BYTES + Longs.BYTES + // length of summary.sampled Ints.BYTES + - // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)] - summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES) + // summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)] + summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES) } final def serialize(obj: PercentileDigest): Array[Byte] = { @@ -312,8 +291,8 @@ object ApproximatePercentile { while (i < summary.sampled.length) { val stat = summary.sampled(i) buffer.putDouble(stat.value) - buffer.putInt(stat.g) - buffer.putInt(stat.delta) + buffer.putLong(stat.g) + buffer.putLong(stat.delta) i += 1 } buffer.array() @@ -330,13 +309,13 @@ object ApproximatePercentile { var i = 0 while (i < sampledLength) { val value = buffer.getDouble() - val g = buffer.getInt() - val delta = buffer.getInt() + val g = buffer.getLong() + val delta = buffer.getLong() sampled(i) = Stats(value, g, delta) i += 1 } - val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count) - new PercentileDigest(summary, isCompressed = true) + val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count, true) + new PercentileDigest(summary) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 19abce01a26c..e1d16a2cd38b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -190,17 +190,15 @@ abstract class AggregateFunction extends Expression { def defaultResult: Option[Literal] = None /** - * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because - * 
[[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, - * and the flag indicating if this aggregation is distinct aggregation or not. - * An [[AggregateFunction]] should not be used without being wrapped in - * an [[AggregateExpression]]. + * Creates [[AggregateExpression]] with `isDistinct` flag disabled. + * + * @see `toAggregateExpression(isDistinct: Boolean)` for detailed description */ def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false) /** - * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and set isDistinct - * field of the [[AggregateExpression]] to the given value because + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets `isDistinct` + * flag of the [[AggregateExpression]] to the given value because * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, * and the flag indicating if this aggregation is distinct aggregation or not. * An [[AggregateFunction]] should not be used without being wrapped in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8bb14598a6d7..d4e322d23b95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -43,16 +43,16 @@ case class UnaryMinus(child: Expression) extends UnaryExpression private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") // codegen would fail to compile if we just write (-($c)) // for example, we could not write --9223372036854775808L in code s""" - ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval); - ${ev.value} = (${ctx.javaType(dt)})(-($originValue)); + ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval); + ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue)); """}) - case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") + case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } protected override def nullSafeEval(input: Any): Any = { @@ -104,10 +104,10 @@ case class Abs(child: Expression) private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => + case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))") } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) @@ -117,19 +117,25 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType - override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** Name of the function for this expression on a [[Decimal]] type. 
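// Sketch (not part of the patch): in the BinaryArithmetic hunk continuing below, the DecimalType
// and CalendarIntervalType branches move into the shared doGenCode, so Add/Subtract only declare
// decimalMethod ("$plus"/"$minus") and calendarIntervalMethod ("add"/"subtract") instead of
// overriding doGenCode. Hypothetical expression, assuming Literal.apply accepts CalendarInterval
// and the Spark 2.x two-argument CalendarInterval constructor:
import org.apache.spark.sql.catalyst.expressions.{Add, Literal}
import org.apache.spark.unsafe.types.CalendarInterval

// Both operands are intervals, so the shared doGenCode takes the CalendarIntervalType branch
// and emits `<left>.add(<right>)` for this Add.
val intervalSum = Add(Literal(new CalendarInterval(0, 1000L)), Literal(new CalendarInterval(1, 0L)))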
*/ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") + /** Name of the function for this expression on a [[CalendarInterval]] type. */ + def calendarIntervalMethod: String = + sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode") + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => + case _: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + case CalendarIntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } @@ -152,6 +158,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" + override def decimalMethod: String = "$plus" + + override def calendarIntervalMethod: String = "add" + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -161,18 +171,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { numeric.plus(input1, input2) } } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") - case ByteType | ShortType => - defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") - case CalendarIntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") - case _ => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } @ExpressionDescription( @@ -188,6 +186,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "-" + override def decimalMethod: String = "$minus" + + override def calendarIntervalMethod: String = "subtract" + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -197,18 +199,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti numeric.minus(input1, input2) } } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") - case ByteType | ShortType => - defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") - case CalendarIntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") - case _ => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } @ExpressionDescription( @@ -278,7 +268,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val divide = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -288,7 +278,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} 
boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -299,7 +289,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -365,7 +355,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val remainder = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -375,7 +365,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -386,7 +376,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -416,7 +406,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "pmod" - protected def checkTypesInternal(t: DataType) = + protected def checkTypesInternal(t: DataType): TypeCheckResult = TypeUtils.checkForNumericExpr(t, "pmod") override def inputType: AbstractDataType = NumericType @@ -454,13 +444,13 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { s"${eval2.value} == 0" } val remainder = ctx.freshName("remainder") - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val result = dataType match { case DecimalType.Fixed(_, _) => val decimalAdd = "$plus" s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value}.remainder(${eval2.value}); + $javaType $remainder = ${eval1.value}.remainder(${eval2.value}); if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { ${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value}); } else { @@ -470,17 +460,16 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => s""" - ${ctx.javaType(dataType)} $remainder = - (${ctx.javaType(dataType)})(${eval1.value} % ${eval2.value}); + $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value}); if ($remainder < 0) { - ${ev.value}=(${ctx.javaType(dataType)})(($remainder + ${eval2.value}) % ${eval2.value}); + ${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value}); } else { ${ev.value}=$remainder; } """ case _ => s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value} % ${eval2.value}; + $javaType $remainder = ${eval1.value} % ${eval2.value}; if ($remainder < 0) { ${ev.value}=($remainder + ${eval2.value}) % ${eval2.value}; } else { @@ -493,7 +482,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { 
ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -504,7 +493,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -602,19 +591,15 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, ev.value, eval.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfSmaller(dataType, ev, eval)} """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "least", @@ -629,7 +614,7 @@ case class Least(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } @@ -681,19 +666,15 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, eval.value, ev.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfGreater(dataType, ev, eval)} """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "greatest", @@ -708,7 +689,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 173481f06a71..cc24e397cc14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -147,7 +147,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) 
~($c)") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dataType)}) ~($c)") } protected override def nullSafeEval(input: Any): Any = not(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2c714c228e6c..4dda52529425 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types._ import org.apache.spark.util.{ParentClassLoader, Utils} @@ -56,7 +57,21 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: String, var value: String) +case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) + +object ExprCode { + def apply(isNull: ExprValue, value: ExprValue): ExprCode = { + ExprCode(code = "", isNull, value) + } + + def forNullValue(dataType: DataType): ExprCode = { + ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) + } + + def forNonNullValue(value: ExprValue): ExprCode = { + ExprCode(code = "", isNull = FalseLiteral, value = value) + } +} /** * State used for subexpression elimination. @@ -66,7 +81,7 @@ case class ExprCode(var code: String, var isNull: String, var value: String) * @param value A term for a value of a common sub-expression. Not valid if `isNull` * is set to `true`. */ -case class SubExprEliminationState(isNull: String, value: String) +case class SubExprEliminationState(isNull: ExprValue, value: ExprValue) /** * Codes and common subexpressions mapping used for subexpression elimination. @@ -99,6 +114,8 @@ private[codegen] case class NewFunctionSpec( */ class CodegenContext { + import CodeGenerator._ + /** * Holding a list of objects that could be used passed into generated class. */ @@ -190,11 +207,11 @@ class CodegenContext { /** * Returns the reference of next available slot in current compacted array. The size of each - * compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * compacted array is controlled by the constant `MUTABLESTATEARRAY_SIZE_LIMIT`. * Once reaching the threshold, new compacted array is created. */ def getNextSlot(): String = { - if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) { + if (currentIndex < MUTABLESTATEARRAY_SIZE_LIMIT) { val res = s"${arrayNames.last}[$currentIndex]" currentIndex += 1 res @@ -241,10 +258,10 @@ class CodegenContext { * are satisfied: * 1. forceInline is true * 2. its type is primitive type and the total number of the inlined mutable variables - * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD` + * is less than `OUTER_CLASS_VARIABLES_THRESHOLD` * 3. its type is multi-dimensional array * When a variable is compacted into an array, the max size of the array for compaction - * is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * is given by `MUTABLESTATEARRAY_SIZE_LIMIT`. 
*/ def addMutableState( javaType: String, @@ -255,7 +272,7 @@ class CodegenContext { // want to put a primitive type variable at outerClass for performance val canInlinePrimitive = isPrimitiveType(javaType) && - (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + (inlinedMutableStates.length < OUTER_CLASS_VARIABLES_THRESHOLD) if (forceInline || canInlinePrimitive || javaType.contains("[][]")) { val varName = if (useFreshName) freshName(variableName) else variableName val initCode = initFunc(varName) @@ -317,7 +334,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, "false", value) + ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } def declareMutableStates(): String = { @@ -333,7 +350,7 @@ class CodegenContext { val length = if (index + 1 == numArrays) { mutableStateArrays.getCurrentIndex } else { - CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + MUTABLESTATEARRAY_SIZE_LIMIT } if (javaType.contains("[]")) { // initializer had an one-dimensional array variable @@ -389,7 +406,7 @@ class CodegenContext { val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // Foreach expression that is participating in subexpression elimination, the state to use. - val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState] // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] @@ -462,7 +479,7 @@ class CodegenContext { inlineToOuterClass: Boolean): NewFunctionSpec = { val (className, classInstance) = if (inlineToOuterClass) { outerClassName -> "" - } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) { + } else if (currClassSize > GENERATED_CLASS_SIZE_THRESHOLD) { val className = freshName("NestedClass") val classInstance = freshName("nestedClassInstance") @@ -531,14 +548,6 @@ class CodegenContext { extraClasses.append(code) } - final val JAVA_BOOLEAN = "boolean" - final val JAVA_BYTE = "byte" - final val JAVA_SHORT = "short" - final val JAVA_INT = "int" - final val JAVA_LONG = "long" - final val JAVA_FLOAT = "float" - final val JAVA_DOUBLE = "double" - /** * The map from a variable name to it's next ID. */ @@ -564,217 +573,20 @@ class CodegenContext { } else { s"${freshNamePrefix}_$name" } - if (freshNameIds.contains(fullName)) { - val id = freshNameIds(fullName) - freshNameIds(fullName) = id + 1 - s"$fullName$id" - } else { - freshNameIds += fullName -> 1 - fullName - } - } - - /** - * Returns the specialized code to access a value from `inputRow` at `ordinal`. 
- */ - def getValue(input: String, dataType: DataType, ordinal: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" - case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" - case StringType => s"$input.getUTF8String($ordinal)" - case BinaryType => s"$input.getBinary($ordinal)" - case CalendarIntervalType => s"$input.getInterval($ordinal)" - case t: StructType => s"$input.getStruct($ordinal, ${t.size})" - case _: ArrayType => s"$input.getArray($ordinal)" - case _: MapType => s"$input.getMap($ordinal)" - case NullType => "null" - case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) - case _ => s"($jt)$input.get($ordinal, null)" - } - } - - /** - * Returns the code to update a column in Row for a given DataType. - */ - def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" - case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) - // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy - // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. - case StringType | _: StructType | _: ArrayType | _: MapType => - s"$row.update($ordinal, $value.copy())" - case _ => s"$row.update($ordinal, $value)" - } - } - - /** - * Update a column in MutableRow from ExprCode. - * - * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise - */ - def updateColumn( - row: String, - dataType: DataType, - ordinal: Int, - ev: ExprCode, - nullable: Boolean, - isVectorized: Boolean = false): String = { - if (nullable) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - if (!isVectorized && dataType.isInstanceOf[DecimalType]) { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - ${setColumn(row, dataType, ordinal, "null")}; - } - """ - } else { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - $row.setNullAt($ordinal); - } - """ - } - } else { - s"""${setColumn(row, dataType, ordinal, ev.value)};""" - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType`. - */ - def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => - s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" - case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" - case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" - case _ => - throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType` - * that could potentially be nullable. 
- */ - def updateColumn( - vector: String, - rowId: String, - dataType: DataType, - ev: ExprCode, - nullable: Boolean): String = { - if (nullable) { - s""" - if (!${ev.isNull}) { - ${setValue(vector, rowId, dataType, ev.value)} - } else { - $vector.putNull($rowId); - } - """ - } else { - s"""${setValue(vector, rowId, dataType, ev.value)};""" - } - } - - /** - * Returns the specialized code to access a value from a column vector for a given `DataType`. - */ - def getValue(vector: String, rowId: String, dataType: DataType): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => - s"$vector.get${primitiveTypeName(jt)}($rowId)" - case t: DecimalType => - s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})" - case StringType => - s"$vector.getUTF8String($rowId)" - case _ => - throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") - } - } - - /** - * Returns the name used in accessor and setter for a Java primitive type. - */ - def primitiveTypeName(jt: String): String = jt match { - case JAVA_INT => "Int" - case _ => boxedType(jt) - } - - def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) - - /** - * Returns the Java type for a DataType. - */ - def javaType(dt: DataType): String = dt match { - case BooleanType => JAVA_BOOLEAN - case ByteType => JAVA_BYTE - case ShortType => JAVA_SHORT - case IntegerType | DateType => JAVA_INT - case LongType | TimestampType => JAVA_LONG - case FloatType => JAVA_FLOAT - case DoubleType => JAVA_DOUBLE - case dt: DecimalType => "Decimal" - case BinaryType => "byte[]" - case StringType => "UTF8String" - case CalendarIntervalType => "CalendarInterval" - case _: StructType => "InternalRow" - case _: ArrayType => "ArrayData" - case _: MapType => "MapData" - case udt: UserDefinedType[_] => javaType(udt.sqlType) - case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" - case ObjectType(cls) => cls.getName - case _ => "Object" - } - - /** - * Returns the boxed type in Java. - */ - def boxedType(jt: String): String = jt match { - case JAVA_BOOLEAN => "Boolean" - case JAVA_BYTE => "Byte" - case JAVA_SHORT => "Short" - case JAVA_INT => "Integer" - case JAVA_LONG => "Long" - case JAVA_FLOAT => "Float" - case JAVA_DOUBLE => "Double" - case other => other - } - - def boxedType(dt: DataType): String = boxedType(javaType(dt)) - - /** - * Returns the representation of default value for a given Java Type. - */ - def defaultValue(jt: String): String = jt match { - case JAVA_BOOLEAN => "false" - case JAVA_BYTE => "(byte)-1" - case JAVA_SHORT => "(short)-1" - case JAVA_INT => "-1" - case JAVA_LONG => "-1L" - case JAVA_FLOAT => "-1.0f" - case JAVA_DOUBLE => "-1.0" - case _ => "null" + val id = freshNameIds.getOrElse(fullName, 0) + freshNameIds(fullName) = id + 1 + s"${fullName}_$id" } - def defaultValue(dt: DataType): String = defaultValue(javaType(dt)) - /** * Generates code for equal expression in Java. 
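// Sketch (not part of the patch): with the simplified freshName above, every generated name now
// carries an `_<id>` suffix starting at 0 (previously the first use of a name was returned bare):
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

val ctx = new CodegenContext
ctx.freshName("value")   // "value_0"
ctx.freshName("value")   // "value_1"
ctx.freshName("isNull")  // "isNull_0"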
*/ def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { case BinaryType => s"java.util.Arrays.equals($c1, $c2)" - case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" - case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" + case FloatType => + s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)" + case DoubleType => + s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" case array: ArrayType => genComp(array, c1, c2) + " == 0" @@ -810,6 +622,7 @@ class CodegenContext { val isNullB = freshName("isNullB") val compareFunc = freshName("compareArray") val minLength = freshName("minLength") + val jt = javaType(elementType) val funcCode: String = s""" public int $compareFunc(ArrayData a, ArrayData b) { @@ -831,8 +644,8 @@ class CodegenContext { } else if ($isNullB) { return 1; } else { - ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")}; - ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")}; + $jt $elementA = ${getValue("a", elementType, "i")}; + $jt $elementB = ${getValue("b", elementType, "i")}; int comp = ${genComp(elementType, elementA, elementB)}; if (comp != 0) { return comp; @@ -884,6 +697,73 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is smaller than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfSmaller(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, partialResult.value, item.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + + /** + * Generates code for updating `partialResult` if `item` is greater than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, item.value, partialResult.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + + /** + * Generates code creating a [[UnsafeArrayData]]. 
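// Sketch (not part of the patch): Least/Greatest (earlier in this patch) now delegate their
// per-child comparison to reassignIfSmaller/reassignIfGreater. Hypothetical inputs, using
// generated-variable names that stand in for real codegen state:
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, JavaCode}
import org.apache.spark.sql.types.DoubleType

val ctx = new CodegenContext
val partial = ExprCode(JavaCode.isNullGlobal("gIsNull"), JavaCode.global("gValue", DoubleType))
val child   = ExprCode(JavaCode.isNullGlobal("cIsNull"), JavaCode.global("cValue", DoubleType))
// Emits roughly:
//   if (!cIsNull && (gIsNull || <genGreater(DoubleType, cValue, gValue)>)) { gIsNull = false; gValue = cValue; }
val update = ctx.reassignIfGreater(DoubleType, partial, child)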
+ * + * @param arrayName name of the array to create + * @param numElements code representing the number of elements the array should contain + * @param elementType data type of the elements in the array + * @param additionalErrorMessage string to include in the error message + */ + def createUnsafeArray( + arrayName: String, + numElements: String, + elementType: DataType, + additionalErrorMessage: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + + s""" + |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | ${elementType.defaultSize}); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try create array with " + $arraySize + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." + + | "$additionalErrorMessage"); + |} + |byte[] $arrayBytes = new byte[(int)$arraySize]; + |UnsafeArrayData $arrayName = new UnsafeArrayData(); + |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. @@ -904,19 +784,6 @@ class CodegenContext { } } - /** - * List of java data types that have special accessors and setters in [[InternalRow]]. - */ - val primitiveTypes = - Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) - - /** - * Returns true if the Java type has a special accessor and setter in [[InternalRow]]. - */ - def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) - - def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) - /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow @@ -1087,7 +954,7 @@ class CodegenContext { // for performance reasons, the functions are prepended, instead of appended, // thus here they are in reversed order val orderedFunctions = innerClassFunctions.reverse - if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) { + if (orderedFunctions.size > MERGE_SPLIT_METHODS_THRESHOLD) { // Adding a new function to each inner class which contains the invocation of all the // ones which have been added to that inner class. 
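// Sketch (not part of the patch): callers of createUnsafeArray (defined just above) pass the
// *names* of variables in the generated Java, not Scala values; the returned string declares the
// byte[] backing store and an UnsafeArrayData pointing into it. Hypothetical call, where
// "numElements" is assumed to be a variable already defined in the surrounding generated code:
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.LongType

val ctx = new CodegenContext
val arrayCode = ctx.createUnsafeArray("tmpArray", "numElements", LongType, " while building example output.")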
For example, // private class NestedClass { @@ -1122,14 +989,12 @@ class CodegenContext { newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])( f: => Seq[ExprCode]): Seq[ExprCode] = { val oldsubExprEliminationExprs = subExprEliminationExprs - subExprEliminationExprs.clear - newSubExprEliminationExprs.foreach(subExprEliminationExprs += _) + subExprEliminationExprs = newSubExprEliminationExprs val genCodes = f // Restore previous subExprEliminationExprs - subExprEliminationExprs.clear - oldsubExprEliminationExprs.foreach(subExprEliminationExprs += _) + subExprEliminationExprs = oldsubExprEliminationExprs genCodes } @@ -1143,7 +1008,7 @@ class CodegenContext { def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree) @@ -1156,10 +1021,10 @@ class CodegenContext { // Generate the code for this expression tree. val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) - e.foreach(subExprEliminationExprs.put(_, state)) + e.foreach(localSubExprEliminationExprs.put(_, state)) eval.code.trim } - SubExprCodes(codes, subExprEliminationExprs.toMap) + SubExprCodes(codes, localSubExprEliminationExprs.toMap) } /** @@ -1206,8 +1071,10 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) - e.foreach(subExprEliminationExprs.put(_, state)) + val state = SubExprEliminationState( + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, expr.dataType)) + subExprEliminationExprs ++= e.map(_ -> state).toMap } } @@ -1230,14 +1097,29 @@ class CodegenContext { /** * Register a comment and return the corresponding place holder + * + * @param placeholderId an optionally specified identifier for the comment's placeholder. + * The caller should make sure this identifier is unique within the + * compilation unit. If this argument is not specified, a fresh identifier + * will be automatically created and used as the placeholder. + * @param force whether to force registering the comments */ - def registerComment(text: => String): String = { + def registerComment( + text: => String, + placeholderId: String = "", + force: Boolean = false): String = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this // flat, see SPARK-15680. 
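// Sketch (not part of the patch): the two new registerComment parameters let a caller pin a
// stable placeholder id (which must be unique in the compilation unit) and force registration
// even when spark.sql.codegen.comments is disabled; the returned string is a placeholder that
// is substituted with the comment text later. Hypothetical call:
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

val ctx = new CodegenContext
val placeholder = ctx.registerComment("input schema: struct<a:int,b:string>", placeholderId = "schema_note", force = true)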
- if (SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { - val name = freshName("c") + if (force || + SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { + val name = if (placeholderId != "") { + assert(!placeHolderToComments.contains(placeholderId)) + placeholderId + } else { + freshName("c") + } val comment = if (text.contains("\n") || text.contains("\r")) { text.split("(\r\n)|\r|\n").mkString("/**\n * ", "\n * ", "\n */") } else { @@ -1315,26 +1197,29 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin object CodeGenerator extends Logging { // This is the value of HugeMethodLimit in the OpenJDK JVM settings - val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + final val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + + // The max valid length of method parameters in JVM. + final val MAX_JVM_METHOD_PARAMS_LENGTH = 255 // This is the threshold over which the methods in an inner class are grouped in a single // method which is going to be called by the outer class instead of the many small ones - val MERGE_SPLIT_METHODS_THRESHOLD = 3 + final val MERGE_SPLIT_METHODS_THRESHOLD = 3 // The number of named constants that can exist in the class is limited by the Constant Pool // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a // threshold of 1000k bytes to determine when a function should be inlined to a private, inner // class. - val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 + final val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 // This is the threshold for the number of global variables, whose types are primitive type or // complex type (e.g. more than one-dimensional array), that will be placed at the outer class - val OUTER_CLASS_VARIABLES_THRESHOLD = 10000 + final val OUTER_CLASS_VARIABLES_THRESHOLD = 10000 // This is the maximum number of array elements to keep global variables in one Java array // 32767 is the maximum integer value that does not require a constant pool entry in a Java // bytecode instruction - val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 + final val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 /** * Compile the Java source code into a Java class, using Janino. @@ -1479,4 +1364,267 @@ object CodeGenerator extends Logging { result } }) + + /** + * Name of Java primitive data type + */ + final val JAVA_BOOLEAN = "boolean" + final val JAVA_BYTE = "byte" + final val JAVA_SHORT = "short" + final val JAVA_INT = "int" + final val JAVA_LONG = "long" + final val JAVA_FLOAT = "float" + final val JAVA_DOUBLE = "double" + + /** + * List of java primitive data types + */ + val primitiveTypes = + Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) + + /** + * Returns true if a Java type is Java primitive primitive type + */ + def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) + + def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) + + /** + * Returns the specialized code to access a value from `inputRow` at `ordinal`. 
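// Sketch (not part of the patch): these helpers are now static on the CodeGenerator companion
// object, so expression code generators can call them without a CodegenContext in scope:
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.types.{IntegerType, StringType}

CodeGenerator.javaType(IntegerType)            // "int"
CodeGenerator.isPrimitiveType(StringType)      // false ("UTF8String" is not a primitive name)
CodeGenerator.getValue("i", IntegerType, "0")  // "i.getInt(0)"
CodeGenerator.getValue("i", StringType, "1")   // "i.getUTF8String(1)"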
+ */ + def getValue(input: String, dataType: DataType, ordinal: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" + case StringType => s"$input.getUTF8String($ordinal)" + case BinaryType => s"$input.getBinary($ordinal)" + case CalendarIntervalType => s"$input.getInterval($ordinal)" + case t: StructType => s"$input.getStruct($ordinal, ${t.size})" + case _: ArrayType => s"$input.getArray($ordinal)" + case _: MapType => s"$input.getMap($ordinal)" + case NullType => "null" + case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) + case _ => s"($jt)$input.get($ordinal, null)" + } + } + + /** + * Returns the code to update a column in Row for a given DataType. + */ + def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) + // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy + // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. + case StringType | _: StructType | _: ArrayType | _: MapType => + s"$row.update($ordinal, $value.copy())" + case _ => s"$row.update($ordinal, $value)" + } + } + + /** + * Update a column in MutableRow from ExprCode. + * + * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise + */ + def updateColumn( + row: String, + dataType: DataType, + ordinal: Int, + ev: ExprCode, + nullable: Boolean, + isVectorized: Boolean = false): String = { + if (nullable) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + if (!isVectorized && dataType.isInstanceOf[DecimalType]) { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | ${setColumn(row, dataType, ordinal, "null")}; + |} + """.stripMargin + } else { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | $row.setNullAt($ordinal); + |} + """.stripMargin + } + } else { + s"""${setColumn(row, dataType, ordinal, ev.value)};""" + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType`. + */ + def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => + s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" + case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" + case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" + case _ => + throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType` + * that could potentially be nullable. 
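// Sketch (not part of the patch): what the row-update helpers above produce for a nullable int
// column at ordinal 0, given hypothetical generated variables isNull_0 / value_0:
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, ExprCode, JavaCode}
import org.apache.spark.sql.types.IntegerType

val ev = ExprCode(JavaCode.isNullGlobal("isNull_0"), JavaCode.global("value_0", IntegerType))
// Roughly: if (!isNull_0) { mutableRow.setInt(0, value_0); } else { mutableRow.setNullAt(0); }
val update = CodeGenerator.updateColumn("mutableRow", IntegerType, 0, ev, nullable = true)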
+ */ + def updateColumn( + vector: String, + rowId: String, + dataType: DataType, + ev: ExprCode, + nullable: Boolean): String = { + if (nullable) { + s""" + |if (!${ev.isNull}) { + | ${setValue(vector, rowId, dataType, ev.value)} + |} else { + | $vector.putNull($rowId); + |} + """.stripMargin + } else { + s"""${setValue(vector, rowId, dataType, ev.value)};""" + } + } + + /** + * Returns the specialized code to access a value from a column vector for a given `DataType`. + */ + def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { + if (dataType.isInstanceOf[StructType]) { + // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an + // `ordinal` parameter. + s"$vector.getStruct($rowId)" + } else { + getValue(vector, dataType, rowId) + } + } + + /** + * Returns the name used in accessor and setter for a Java primitive type. + */ + def primitiveTypeName(jt: String): String = jt match { + case JAVA_INT => "Int" + case _ => boxedType(jt) + } + + def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) + + /** + * Returns the Java type for a DataType. + */ + def javaType(dt: DataType): String = dt match { + case BooleanType => JAVA_BOOLEAN + case ByteType => JAVA_BYTE + case ShortType => JAVA_SHORT + case IntegerType | DateType => JAVA_INT + case LongType | TimestampType => JAVA_LONG + case FloatType => JAVA_FLOAT + case DoubleType => JAVA_DOUBLE + case _: DecimalType => "Decimal" + case BinaryType => "byte[]" + case StringType => "UTF8String" + case CalendarIntervalType => "CalendarInterval" + case _: StructType => "InternalRow" + case _: ArrayType => "ArrayData" + case _: MapType => "MapData" + case udt: UserDefinedType[_] => javaType(udt.sqlType) + case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" + case ObjectType(cls) => cls.getName + case _ => "Object" + } + + def javaClass(dt: DataType): Class[_] = dt match { + case BooleanType => java.lang.Boolean.TYPE + case ByteType => java.lang.Byte.TYPE + case ShortType => java.lang.Short.TYPE + case IntegerType | DateType => java.lang.Integer.TYPE + case LongType | TimestampType => java.lang.Long.TYPE + case FloatType => java.lang.Float.TYPE + case DoubleType => java.lang.Double.TYPE + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case udt: UserDefinedType[_] => javaClass(udt.sqlType) + case ObjectType(cls) => cls + case _ => classOf[Object] + } + + /** + * Returns the boxed type in Java. + */ + def boxedType(jt: String): String = jt match { + case JAVA_BOOLEAN => "Boolean" + case JAVA_BYTE => "Byte" + case JAVA_SHORT => "Short" + case JAVA_INT => "Integer" + case JAVA_LONG => "Long" + case JAVA_FLOAT => "Float" + case JAVA_DOUBLE => "Double" + case other => other + } + + def boxedType(dt: DataType): String = boxedType(javaType(dt)) + + /** + * Returns the representation of default value for a given Java Type. 
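// Sketch (not part of the patch): a few of the type-mapping helpers in this hunk, plus the
// defaultValue and parameter-length helpers added just below:
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.types.{DoubleType, StringType}

CodeGenerator.javaClass(StringType)                       // classOf[UTF8String]
CodeGenerator.boxedType(DoubleType)                       // "Double"
CodeGenerator.defaultValue(DoubleType)                    // "-1.0"
CodeGenerator.defaultValue(StringType, typedNull = true)  // "((UTF8String)null)"
// Parameter-length rule below: 1 slot for `this`, 2 for a long/double argument, plus 1 per
// nullable argument for its null flag; e.g. two nullable double children => 1 + (2+1) + (2+1) = 7.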
+ * @param jt the string name of the Java type + * @param typedNull if true, for null literals, return a typed (with a cast) version + */ + def defaultValue(jt: String, typedNull: Boolean): String = jt match { + case JAVA_BOOLEAN => "false" + case JAVA_BYTE => "(byte)-1" + case JAVA_SHORT => "(short)-1" + case JAVA_INT => "-1" + case JAVA_LONG => "-1L" + case JAVA_FLOAT => "-1.0f" + case JAVA_DOUBLE => "-1.0" + case _ => if (typedNull) s"(($jt)null)" else "null" + } + + def defaultValue(dt: DataType, typedNull: Boolean = false): String = + defaultValue(javaType(dt), typedNull) + + /** + * Returns the length of parameters for a Java method descriptor. `this` contributes one unit + * and a parameter of type long or double contributes two units. Besides, for nullable parameter, + * we also need to pass a boolean parameter for the null status. + */ + def calculateParamLength(params: Seq[Expression]): Int = { + def paramLengthForExpr(input: Expression): Int = { + val javaParamLength = javaType(input.dataType) match { + case JAVA_LONG | JAVA_DOUBLE => 2 + case _ => 1 + } + // For a nullable expression, we need to pass in an extra boolean parameter. + (if (input.nullable) 1 else 0) + javaParamLength + } + // Initial value is 1 for `this`. + 1 + params.map(paramLengthForExpr).sum + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length less than a pre-defined constant. + */ + def isValidParamLength(paramLength: Int): Boolean = { + paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 0322d1dd6a9f..a91989e12966 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -44,21 +44,22 @@ trait CodegenFallback extends Expression { } val objectTerm = ctx.freshName("obj") val placeHolder = ctx.registerComment(this.toString) + val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(this.dataType)}; if (!${ev.isNull}) { - ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); - ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - """, isNull = "false") + $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; + """, isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b53c0087e7e2..33d14329ec95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -52,43 +52,45 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP expressions: Seq[Expression], useSubexprElimination: Boolean): MutableProjection = { val ctx = newCodeGenContext() - val (validExpr, index) = expressions.zipWithIndex.filter { + val validExpr = expressions.zipWithIndex.filter { case (NoOp, _) => false case _ => true - }.unzip - val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) + } + val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) - val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { - case (ev, i) => - val e = expressions(i) - val value = ctx.addMutableState(ctx.javaType(e.dataType), "value") - if (e.nullable) { - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull") + val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { + case ((e, i), ev) => + val value = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value"), + e.dataType) + val (code, isNull) = if (e.nullable) { + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull") (s""" |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, isNull, value, i) + """.stripMargin, JavaCode.isNullGlobal(isNull)) } else { (s""" |${ev.code} |$value = ${ev.value}; - """.stripMargin, ev.isNull, value, i) + """.stripMargin, FalseLiteral) } + val update = CodeGenerator.updateColumn( + "mutableRow", + e.dataType, + i, + ExprCode(isNull, value), + e.nullable) + (code, update) } // Evaluate all the subexpressions. 
val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val updates = validExpr.zip(projectionCodes).map { - case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, value) - ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) - } - val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) - val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates) + val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)) val codeBody = s""" public java.lang.Object generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4a459571ed63..9a51be6ed5ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -89,7 +89,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR s""" ${ctx.INPUT_ROW} = a; boolean $isNullA; - ${ctx.javaType(order.child.dataType)} $primitiveA; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveA; { ${eval.code} $isNullA = ${eval.isNull}; @@ -97,7 +97,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } ${ctx.INPUT_ROW} = b; boolean $isNullB; - ${ctx.javaType(order.child.dataType)} $primitiveB; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveB; { ${eval.code} $isNullB = ${eval.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 3dcbb518ba42..01c350e9dbf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.annotation.tailrec +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** @@ -53,7 +54,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt) + val converter = convertToSafe( + ctx, + JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt), + dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -74,7 +78,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[InternalRow])) } private def createCodeForArray( @@ -90,7 +94,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val arrayClass = 
classOf[GenericArrayData].getName val elementConverter = convertToSafe( - ctx, ctx.getValue(tmpInput, elementType, index), elementType) + ctx, + JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), + elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -104,7 +110,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[ArrayData])) } private def createCodeForMap( @@ -125,19 +131,19 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[MapData])) } @tailrec private def convertToSafe( ctx: CodegenContext, - input: String, + input: ExprValue, dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => ExprCode("", "false", input) + case _ => ExprCode(FalseLiteral, input) } protected def create(expressions: Seq[Expression]): Projection = { @@ -153,7 +159,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] mutableRow.setNullAt($i); } else { ${converter.code} - ${ctx.setColumn("mutableRow", e.dataType, i, converter.value)}; + ${CodeGenerator.setColumn("mutableRow", e.dataType, i, converter.value)}; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 36ffa8dcdd2b..01b4d6c4529b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -32,14 +32,13 @@ import org.apache.spark.sql.types._ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { /** Returns true iff we support this data type. */ - def canSupport(dataType: DataType): Boolean = dataType match { + def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match { case NullType => true - case t: AtomicType => true + case _: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true - case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -47,22 +46,33 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private def writeStructToBuffer( ctx: CodegenContext, input: String, + index: String, fieldTypes: Seq[DataType], - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
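The struct writer below no longer threads a BufferHolder through; it asks the row writer for its cursor before writing the nested row and records the (offset, size) word afterwards via setOffsetAndSizeFromPreviousCursor. A simplified, self-contained illustration of that cursor bookkeeping, assuming a hypothetical TinyRowWriter rather than Spark's UnsafeRowWriter:

import java.io.ByteArrayOutputStream

final class TinyRowWriter {
  private val buffer = new ByteArrayOutputStream()
  def cursor: Int = buffer.size()
  def writeBytes(bytes: Array[Byte]): Unit = buffer.write(bytes, 0, bytes.length)
  // Packs the offset in the upper 32 bits and the size in the lower 32 bits of one long.
  def offsetAndSizeFrom(previousCursor: Int): Long =
    (previousCursor.toLong << 32) | (cursor - previousCursor).toLong
}

object TinyRowWriterDemo {
  def main(args: Array[String]): Unit = {
    val writer = new TinyRowWriter
    val previousCursor = writer.cursor                  // remember where the nested data starts
    writer.writeBytes("nested struct payload".getBytes("UTF-8"))
    val word = writer.offsetAndSizeFrom(previousCursor) // offset 0, size 21 in this run
    println(s"offset=${word >>> 32} size=${word & 0xFFFFFFFFL}")
  }
}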
val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString)) + ExprCode( + JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"), + JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") + val previousCursor = ctx.freshName("previousCursor") s""" - final InternalRow $tmpInput = $input; - if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)} - } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)} - } - """ + |final InternalRow $tmpInput = $input; + |if ($tmpInput instanceof UnsafeRow) { + | $rowWriter.write($index, (UnsafeRow) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. + | final int $previousCursor = $rowWriter.cursor(); + | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin } private def writeExpressionsToBuffer( @@ -70,12 +80,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro row: String, inputs: Seq[ExprCode], inputTypes: Seq[DataType], - bufferHolder: String, + rowWriter: String, isTopLevel: Boolean = false): String = { - val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});") - val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, // which means its fixed-size region always in the same position, so we don't need to call @@ -88,16 +94,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.zeroOutNullBytes();" } } else { - s"$rowWriter.reset();" + s"$rowWriter.resetRowWriter();" } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => - val dt = dataType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } - val tmpCursor = ctx.freshName("tmpCursor") + val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -106,56 +108,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$rowWriter.setNullAt($index);" } - val writeField = dt match { - case t: StructType => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); - """ - - case a @ ArrayType(et, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. 
- final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); - """ - - case t: DecimalType => - s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$rowWriter.write($index, ${input.value});" - } - - if (input.isNull == "false") { + val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) + if (input.isNull == FalseLiteral) { s""" - ${input.code} - ${writeField.trim} - """ + |${input.code} + |${writeField.trim} + """.stripMargin } else { s""" - ${input.code} - if (${input.isNull}) { - ${setNull.trim} - } else { - ${writeField.trim} - } - """ + |${input.code} + |if (${input.isNull}) { + | ${setNull.trim} + |} else { + | ${writeField.trim} + |} + """.stripMargin } } @@ -169,11 +136,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro funcName = "writeFields", arguments = Seq("InternalRow" -> row)) } - s""" - $resetWriter - $writeFieldsCode - """.trim + |$resetWriter + |$writeFieldsCode + """.stripMargin } // TODO: if the nullability of array element is correct, we can use it to save null check. @@ -181,126 +147,119 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, elementType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
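One detail worth calling out in the rewritten field-writing loop above: the null branch is emitted only when an expression's isNull is not statically the FalseLiteral object, so provably non-null inputs skip the runtime check, and the comparison is against a typed literal rather than the string "false". A self-contained sketch of that elision, with hypothetical ExprValue/Variable types standing in for Spark's:

object NullCheckElisionSketch {
  sealed trait ExprValue { def code: String }
  case object FalseLiteral extends ExprValue { def code: String = "false" }
  final case class Variable(name: String) extends ExprValue { def code: String = name }

  def writeField(isNull: ExprValue, writeCode: String, setNullCode: String): String =
    if (isNull == FalseLiteral) {
      writeCode // provably non-null: no runtime branch emitted
    } else {
      s"if (${isNull.code}) { $setNullCode } else { $writeCode }"
    }

  def main(args: Array[String]): Unit = {
    println(writeField(FalseLiteral, "writer.write(0, value);", "writer.setNullAt(0);"))
    println(writeField(Variable("isNull_1"), "writer.write(1, value);", "writer.setNullAt(1);"))
  }
}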
val tmpInput = ctx.freshName("tmpInput") - val arrayWriterClass = classOf[UnsafeArrayWriter].getName - val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", - v => s"$v = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") - val et = elementType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } + val et = UserDefinedType.sqlType(elementType) - val jt = ctx.javaType(et) + val jt = CodeGenerator.javaType(et) val elementOrOffsetSize = et match { case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 - case _ if ctx.isPrimitiveType(jt) => et.defaultSize + case _ if CodeGenerator.isPrimitiveType(jt) => et.defaultSize case _ => 8 // we need 8 bytes to store offset and length } - val tmpCursor = ctx.freshName("tmpCursor") - val element = ctx.getValue(tmpInput, et, index) - val writeElement = et match { - case t: StructType => - s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); - """ - - case a @ ArrayType(et, _) => - s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, element, et, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); - """ - - case t: DecimalType => - s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$arrayWriter.write($index, $element);" - } + val arrayWriterClass = classOf[UnsafeArrayWriter].getName + val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", + v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") + + val element = CodeGenerator.getValue(tmpInput, et, index) - val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" s""" - final ArrayData $tmpInput = $input; - if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)} - } else { - final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); - - for (int $index = 0; $index < $numElements; $index++) { - if ($tmpInput.isNullAt($index)) { - $arrayWriter.setNull$primitiveTypeName($index); - } else { - $writeElement - } - } - } - """ + |final ArrayData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeArrayData) { + | $rowWriter.write((UnsafeArrayData) $tmpInput); + |} else { + | final int $numElements = $tmpInput.numElements(); + | $arrayWriter.initialize($numElements); + | + | for (int $index = 0; $index < $numElements; $index++) { + | if ($tmpInput.isNullAt($index)) { + | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); + | } else { + | ${writeElement(ctx, element, index, et, arrayWriter)} + | } + | } + |} + """.stripMargin } // TODO: if the nullability of value element is correct, we can use it to save null check. 
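The map writer that follows reserves an 8-byte slot, writes the key array, backfills the key array's byte length into the reserved slot, and then writes the value array immediately after it. The sketch below reproduces that reserve-write-backfill order with a plain java.nio.ByteBuffer instead of Spark's writer; the payload strings are placeholders, not real UnsafeArrayData bytes.

import java.nio.ByteBuffer

object MapLayoutSketch {
  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(64)

    val lengthSlot = buf.position()   // where the 8-byte key-array-size word will live
    buf.position(lengthSlot + 8)      // reserve it, as the generated code grows/advances by 8

    val keyStart = buf.position()
    buf.put("key-array-bytes".getBytes("UTF-8"))
    val keyNumBytes = buf.position() - keyStart

    buf.putLong(lengthSlot, keyNumBytes.toLong)  // backfill the key array size
    buf.put("value-array-bytes".getBytes("UTF-8"))

    println(s"key array: offset=$keyStart, numBytes=$keyNumBytes, total=${buf.position()} bytes")
  }
}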
private def writeMapToBuffer( ctx: CodegenContext, input: String, + index: String, keyType: DataType, valueType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") + val previousCursor = ctx.freshName("previousCursor") // Writes out unsafe map according to the format described in `UnsafeMapData`. s""" - final MapData $tmpInput = $input; - if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)} - } else { - // preserve 8 bytes to write the key array numBytes later. - $bufferHolder.grow(8); - $bufferHolder.cursor += 8; + |final MapData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeMapData) { + | $rowWriter.write($index, (UnsafeMapData) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. + | final int $previousCursor = $rowWriter.cursor(); + | + | // preserve 8 bytes to write the key array numBytes later. + | $rowWriter.grow(8); + | $rowWriter.increaseCursor(8); + | + | // Remember the current cursor so that we can write numBytes of key array later. + | final int $tmpCursor = $rowWriter.cursor(); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} + | + | // Write the numBytes of key array into the first 8 bytes. + | Platform.putLong( + | $rowWriter.getBuffer(), + | $tmpCursor - 8, + | $rowWriter.cursor() - $tmpCursor); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin + } + + private def writeElement( + ctx: CodegenContext, + input: String, + index: String, + dt: DataType, + writer: String): String = dt match { + case t: StructType => + writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer) + + case ArrayType(et, _) => + val previousCursor = ctx.freshName("previousCursor") + s""" + |// Remember the current cursor so that we can calculate how many bytes are + |// written later. + |final int $previousCursor = $writer.cursor(); + |${writeArrayToBuffer(ctx, input, et, writer)} + |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + """.stripMargin - // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $bufferHolder.cursor; + case MapType(kt, vt, _) => + writeMapToBuffer(ctx, input, index, kt, vt, writer) - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)} - // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); + case DecimalType.Fixed(precision, scale) => + s"$writer.write($index, $input, $precision, $scale);" - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)} - } - """ - } + case NullType => "" - /** - * If the input is already in unsafe format, we don't need to go through all elements/fields, - * we can directly write it. - */ - private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = { - val sizeInBytes = ctx.freshName("sizeInBytes") - s""" - final int $sizeInBytes = $input.getSizeInBytes(); - // grow the global buffer before writing data. 
- $bufferHolder.grow($sizeInBytes); - $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor); - $bufferHolder.cursor += $sizeInBytes; - """ + case _ => s"$writer.write($index, $input);" } def createCode( @@ -316,38 +275,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => true } - val result = ctx.addMutableState("UnsafeRow", "result", - v => s"$v = new UnsafeRow(${expressions.length});") - - val holderClass = classOf[BufferHolder].getName - val holder = ctx.addMutableState(holderClass, "holder", - v => s"$v = new $holderClass($result, ${numVarLenFields * 32});") - - val resetBufferHolder = if (numVarLenFields == 0) { - "" - } else { - s"$holder.reset();" - } - val updateRowSize = if (numVarLenFields == 0) { - "" - } else { - s"$result.setTotalSize($holder.totalSize());" - } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") // Evaluate all the subexpression. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val writeExpressions = - writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + val writeExpressions = writeExpressionsToBuffer( + ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = s""" - $resetBufferHolder - $evalSubexpr - $writeExpressions - $updateRowSize - """ - ExprCode(code, "false", result) + |$rowWriter.reset(); + |$evalSubexpr + |$writeExpressions + """.stripMargin + // `rowWriter` is declared as a class field, so we can access it directly in methods. + ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow])) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = @@ -372,38 +317,39 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val ctx = newCodeGenContext() val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) - val codeBody = s""" - public java.lang.Object generate(Object[] references) { - return new SpecificUnsafeProjection(references); - } - - class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { - - private Object[] references; - ${ctx.declareMutableStates()} - - public SpecificUnsafeProjection(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} - } - - public void initialize(int partitionIndex) { - ${ctx.initPartition()} - } - - // Scala.Function1 need this - public java.lang.Object apply(java.lang.Object row) { - return apply((InternalRow) row); - } - - public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - ${eval.code.trim} - return ${eval.value}; - } - - ${ctx.declareAddedFunctions()} - } - """ + val codeBody = + s""" + |public java.lang.Object generate(Object[] references) { + | return new SpecificUnsafeProjection(references); + |} + | + |class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { + | + | private Object[] references; + | ${ctx.declareMutableStates()} + | + | public SpecificUnsafeProjection(Object[] references) { + | this.references = references; + | ${ctx.initMutableStates()} + | } + | + | public void initialize(int partitionIndex) { + | ${ctx.initPartition()} + | } + | + | // Scala.Function1 need this + | public java.lang.Object apply(java.lang.Object row) { + | return apply((InternalRow) row); + | } + | + | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { + | ${eval.code.trim} + 
| return ${eval.value}; + | } + | + | ${ctx.declareAddedFunctions()} + |} + """.stripMargin val code = CodeFormatter.stripOverlappingComments( new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index be5f5a73b5d4..febf7b0c96c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -70,7 +70,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // --------------------- copy bitset from row 1 and row 2 --------------------------- // val copyBitset = Seq.tabulate(outputBitsetWords) { i => - val bits = if (bitset1Remainder > 0) { + val bits = if (bitset1Remainder > 0 && bitset2Words != 0) { if (i < bitset1Words - 1) { s"$getLong(obj1, offset1 + ${i * 8})" } else if (i == bitset1Words - 1) { @@ -152,7 +152,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } else { // Number of bytes to increase for the offset. Note that since in UnsafeRow we store the // offset in the upper 32 bit of the words, we can just shift the offset to the left by - // 32 and increment that amount in place. + // 32 and increment that amount in place. However, we need to handle the important special + // case of a null field, in which case the offset should be zero and should not have a + // shift added to it. val shift = if (i < schema1.size) { s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L" @@ -160,14 +162,55 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)" } val cursor = offset + outputBitsetWords * 8 + i * 8 - s"$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));\n" + // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's + // output as a de-facto specification for the internal layout of data. + // + // Null-valued fields will always have a data offset of 0 because + // UnsafeRowWriter.setNullAt(ordinal) sets the null bit and stores 0 to in field's + // position in the fixed-length section of the row. As a result, we must NOT add + // `shift` to the offset for null fields. + // + // We could perform a null-check here by inspecting the null-tracking bitmap, but doing + // so could be expensive and will add significant bloat to the generated code. Instead, + // we'll rely on the invariant "stored offset == 0 for variable-length data type implies + // that the field's value is null." + // + // To establish that this invariant holds, we'll prove that a non-null field can never + // have a stored offset of 0. There are two cases to consider: + // + // 1. The non-null field's data is of non-zero length: reading this field's value + // must read data from the variable-length section of the row, so the stored offset + // will actually be used in address calculation and must be correct. The offsets + // count bytes from the start of the UnsafeRow so these offsets will always be + // non-zero because the storage of the offsets themselves takes up space at the + // start of the row. + // 2. The non-null field's data is of zero length (i.e. its data is empty). 
In this + // case, we have to worry about the possibility that an arbitrary offset value was + // stored because we never actually read any bytes using this offset and therefore + // would not crash if it was incorrect. The variable-sized data writing paths in + // UnsafeRowWriter unconditionally calls setOffsetAndSize(ordinal, numBytes) with + // no special handling for the case where `numBytes == 0`. Internally, + // setOffsetAndSize computes the offset without taking the size into account. Thus + // the stored offset is the same non-zero offset that would be used if the field's + // dataSize was non-zero (and in (1) above we've shown that case behaves as we + // expect). + // + // Thus it is safe to perform `existingOffset != 0` checks here in the place of + // more expensive null-bit checks. + s""" + |existingOffset = $getLong(buf, $cursor); + |if (existingOffset != 0) { + | $putLong(buf, $cursor, existingOffset + ($shift << 32)); + |} + """.stripMargin } } val updateOffsets = ctx.splitExpressions( expressions = updateOffset, funcName = "copyBitsetFunc", - arguments = ("long", "numBytesVariableRow1") :: Nil) + arguments = ("long", "numBytesVariableRow1") :: Nil, + makeSplitFunction = (s: String) => "long existingOffset;\n" + s) // ------------------------ Finally, put everything together --------------------------- // val codeBody = s""" @@ -200,6 +243,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | $copyFixedLengthRow2 | $copyVariableLengthRow1 | $copyVariableLengthRow2 + | long existingOffset; | $updateOffsets | | out.pointTo(buf, sizeInBytes); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala new file mode 100644 index 000000000000..74ff01848886 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import java.lang.{Boolean => JBool} + +import scala.language.{existentials, implicitConversions} + +import org.apache.spark.sql.types.{BooleanType, DataType} + +/** + * Trait representing an opaque fragments of java code. + */ +trait JavaCode { + def code: String + override def toString: String = code +} + +/** + * Utility functions for creating [[JavaCode]] fragments. + */ +object JavaCode { + /** + * Create a java literal. 
+ */ + def literal(v: String, dataType: DataType): LiteralValue = dataType match { + case BooleanType if v == "true" => TrueLiteral + case BooleanType if v == "false" => FalseLiteral + case _ => new LiteralValue(v, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a default literal. This is null for reference types, false for boolean types and + * -1 for other primitive types. + */ + def defaultLiteral(dataType: DataType): LiteralValue = { + new LiteralValue( + CodeGenerator.defaultValue(dataType, typedNull = true), + CodeGenerator.javaClass(dataType)) + } + + /** + * Create a local java variable. + */ + def variable(name: String, dataType: DataType): VariableValue = { + variable(name, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a local java variable. + */ + def variable(name: String, javaClass: Class[_]): VariableValue = { + VariableValue(name, javaClass) + } + + /** + * Create a local isNull variable. + */ + def isNullVariable(name: String): VariableValue = variable(name, BooleanType) + + /** + * Create a global java variable. + */ + def global(name: String, dataType: DataType): GlobalValue = { + global(name, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a global java variable. + */ + def global(name: String, javaClass: Class[_]): GlobalValue = { + GlobalValue(name, javaClass) + } + + /** + * Create a global isNull variable. + */ + def isNullGlobal(name: String): GlobalValue = global(name, BooleanType) + + /** + * Create an expression fragment. + */ + def expression(code: String, dataType: DataType): SimpleExprValue = { + expression(code, CodeGenerator.javaClass(dataType)) + } + + /** + * Create an expression fragment. + */ + def expression(code: String, javaClass: Class[_]): SimpleExprValue = { + SimpleExprValue(code, javaClass) + } + + /** + * Create a isNull expression fragment. + */ + def isNullExpression(code: String): SimpleExprValue = { + expression(code, BooleanType) + } +} + +/** + * A typed java fragment that must be a valid java expression. + */ +trait ExprValue extends JavaCode { + def javaType: Class[_] + def isPrimitive: Boolean = javaType.isPrimitive +} + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString +} + + +/** + * A java expression fragment. + */ +case class SimpleExprValue(expr: String, javaType: Class[_]) extends ExprValue { + override def code: String = s"($expr)" +} + +/** + * A local variable java expression. + */ +case class VariableValue(variableName: String, javaType: Class[_]) extends ExprValue { + override def code: String = variableName +} + +/** + * A global variable java expression. + */ +case class GlobalValue(value: String, javaType: Class[_]) extends ExprValue { + override def code: String = value +} + +/** + * A literal java expression. 
+ */ +class LiteralValue(val value: String, val javaType: Class[_]) extends ExprValue with Serializable { + override def code: String = value + + override def equals(arg: Any): Boolean = arg match { + case l: LiteralValue => l.javaType == javaType && l.value == value + case _ => false + } + + override def hashCode(): Int = value.hashCode() * 31 + javaType.hashCode() +} + +case object TrueLiteral extends LiteralValue("true", JBool.TYPE) +case object FalseLiteral extends LiteralValue("false", JBool.TYPE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4270b987d6de..12b9ab2b272a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,9 +20,12 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.{ByteArray, UTF8String} /** * Given an array or map, returns its size. Returns -1 if null. @@ -54,8 +57,8 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : - (${childGen.value}).numElements();""", isNull = "false") + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : + (${childGen.value}).numElements();""", isNull = FalseLiteral) } } @@ -116,47 +119,16 @@ case class MapValues(child: Expression) } /** - * Sorts the input array in ascending / descending order according to the natural ordering of - * the array elements and returns it. + * Common base class for [[SortArray]] and [[ArraySort]]. 
*/ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural ordering of the array elements.", - examples = """ - Examples: - > SELECT _FUNC_(array('b', 'd', 'c', 'a'), true); - ["a","b","c","d"] - """) -// scalastyle:on line.size.limit -case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { +trait ArraySortLike extends ExpectsInputTypes { + protected def arrayExpression: Expression - def this(e: Expression) = this(e, Literal(true)) - - override def left: Expression = base - override def right: Expression = ascendingOrder - override def dataType: DataType = base.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) - - override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - ascendingOrder match { - case Literal(_: Boolean, BooleanType) => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure( - "Sort order in second argument requires a boolean literal.") - } - case ArrayType(dt, _) => - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type ${dt.simpleString}") - case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") - } + protected def nullOrder: NullOrder @transient private lazy val lt: Comparator[Any] = { - val ordering = base.dataType match { + val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -167,9 +139,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - -1 + nullOrder } else if (o2 == null) { - 1 + -nullOrder } else { ordering.compare(o1, o2) } @@ -179,7 +151,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) @transient private lazy val gt: Comparator[Any] = { - val ordering = base.dataType match { + val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -190,9 +162,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - 1 + -nullOrder } else if (o2 == null) { - -1 + nullOrder } else { -ordering.compare(o1, o2) } @@ -200,18 +172,287 @@ case class SortArray(base: Expression, ascendingOrder: Expression) } } - override def nullSafeEval(array: Any, ascending: Any): Any = { - val elementType = base.dataType.asInstanceOf[ArrayType].elementType + def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull + + def sortEval(array: Any, ascending: Boolean): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) if (elementType != NullType) { - java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + java.util.Arrays.sort(data, if 
(ascending) lt else gt) } new GenericArrayData(data.asInstanceOf[Array[Any]]) } + def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { + val arrayData = classOf[ArrayData].getName + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + val array = ctx.freshName("array") + val c = ctx.freshName("c") + if (elementType == NullType) { + s"${ev.value} = $base.copy();" + } else { + val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType) + val sortOrder = ctx.freshName("sortOrder") + val o1 = ctx.freshName("o1") + val o2 = ctx.freshName("o2") + val jt = CodeGenerator.javaType(elementType) + val comp = if (CodeGenerator.isPrimitiveType(elementType)) { + val bt = CodeGenerator.boxedType(elementType) + val v1 = ctx.freshName("v1") + val v2 = ctx.freshName("v2") + s""" + |$jt $v1 = (($bt) $o1).${jt}Value(); + |$jt $v2 = (($bt) $o2).${jt}Value(); + |int $c = ${ctx.genComp(elementType, v1, v2)}; + """.stripMargin + } else { + s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" + } + val nonNullPrimitiveAscendingSort = + if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) { + val javaType = CodeGenerator.javaType(elementType) + val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($order) { + | $javaType[] $array = $base.to${primitiveTypeName}Array(); + | java.util.Arrays.sort($array); + | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array); + |} else + """.stripMargin + } else { + "" + } + s""" + |$nonNullPrimitiveAscendingSort + |{ + | Object[] $array = $base.toObjectArray($elementTypeTerm); + | final int $sortOrder = $order ? 1 : -1; + | java.util.Arrays.sort($array, new java.util.Comparator() { + | @Override public int compare(Object $o1, Object $o2) { + | if ($o1 == null && $o2 == null) { + | return 0; + | } else if ($o1 == null) { + | return $sortOrder * $nullOrder; + | } else if ($o2 == null) { + | return -$sortOrder * $nullOrder; + | } + | $comp + | return $sortOrder * $c; + | } + | }); + | ${ev.value} = new $genericArrayData($array); + |} + """.stripMargin + } + } + +} + +object ArraySortLike { + type NullOrder = Int + // Least: place null element at the first of the array for ascending order + // Greatest: place null element at the end of the array for ascending order + object NullOrder { + val Least: NullOrder = -1 + val Greatest: NullOrder = 1 + } +} + +/** + * Sorts the input array in ascending / descending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order + according to the natural ordering of the array elements. Null elements will be placed + at the beginning of the returned array in ascending order or at the end of the returned + array in descending order. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); + [null,"a","b","c","d"] + """) +// scalastyle:on line.size.limit +case class SortArray(base: Expression, ascendingOrder: Expression) + extends BinaryExpression with ArraySortLike { + + def this(e: Expression) = this(e, Literal(true)) + + override def left: Expression = base + override def right: Expression = ascendingOrder + override def dataType: DataType = base.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) + + override def arrayExpression: Expression = base + override def nullOrder: NullOrder = NullOrder.Least + + override def checkInputDataTypes(): TypeCheckResult = base.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + "Sort order in second argument requires a boolean literal.") + } + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any, ascending: Any): Any = { + sortEval(array, ascending.asInstanceOf[Boolean]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) + } + override def prettyName: String = "sort_array" } + +/** + * Sorts the input array in ascending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must + be orderable. Null elements will be placed at the end of the returned array. + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); + ["a","b","c","d",null] + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike { + + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def arrayExpression: Expression = child + override def nullOrder: NullOrder = NullOrder.Greatest + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + TypeCheckResult.TypeCheckSuccess + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any): Any = { + sortEval(array, true) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) + } + + override def prettyName: String = "array_sort" +} + +/** + * Returns a reversed string or an array with reverse order of elements. 
+ */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.", + examples = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + LQS krapS + > SELECT _FUNC_(array(2, 1, 4, 3)); + [3, 4, 1, 2] + """, + since = "1.5.0", + note = "Reverse logic for arrays is available since 2.4.0." +) +case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + // Input types are utilized by type coercion in ImplicitTypeCasts. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) + + override def dataType: DataType = child.dataType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(input: Any): Any = input match { + case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) + case s: UTF8String => s.reverse() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => dataType match { + case _: StringType => stringCodeGen(ev, c) + case _: ArrayType => arrayCodeGen(ctx, ev, c) + }) + } + + private def stringCodeGen(ev: ExprCode, childName: String): String = { + s"${ev.value} = ($childName).reverse();" + } + + private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + val length = ctx.freshName("length") + val javaElementType = CodeGenerator.javaType(elementType) + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + + val initialization = if (isPrimitiveType) { + s"$childName.copy()" + } else { + s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" + } + + val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length + + val swapAssigments = if (isPrimitiveType) { + val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) + val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) + s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); + |boolean isNullAtL = ${ev.value}.isNullAt(l); + |if(!isNullAtK) { + | $javaElementType el = ${getCall("k")}; + | if(!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | } else { + | ${ev.value}.setNullAt(k); + | } + | ${ev.value}.$setFunc(l, el); + |} else if (!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | ${ev.value}.setNullAt(l); + |}""".stripMargin + } else { + s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" + } + + s""" + |final int $length = $childName.numElements(); + |${ev.value} = $initialization; + |for(int k = 0; k < $numberOfIterations; k++) { + | int l = $length - k - 1; + | $swapAssigments + |} + """.stripMargin + } + + override def prettyName: String = "reverse" +} + /** * Checks if the array (left) has the element (right) */ @@ -270,7 +511,7 @@ case class ArrayContains(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") - val getValue = ctx.getValue(arr, right.dataType, i) + val getValue = CodeGenerator.getValue(arr, right.dataType, i) s""" for (int $i = 0; $i < $arr.numElements(); $i ++) { if ($arr.isNullAt($i)) { @@ -287,3 +528,943 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = 
"_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { + val startInt = startVal.asInstanceOf[Int] + val lengthInt = lengthVal.asInstanceOf[Int] + val arr = xVal.asInstanceOf[ArrayData] + val startIndex = if (startInt == 0) { + throw new RuntimeException( + s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") + } else if (startInt < 0) { + startInt + arr.numElements() + } else { + startInt - 1 + } + if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + + "length must be greater than or equal to 0.") + } + // startIndex can be negative if start is negative and its absolute value is greater than the + // number of elements in the array + if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) + } + val data = arr.toSeq[AnyRef](elementType) + new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + | + "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + | + "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + | $values[$i] = $getValue; + | } + |} + |${ev.value} 
= new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $resLength = 0; + |} + |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")} + |for (int $i = 0; $i < $resLength; $i ++) { + | if ($inputArray.isNullAt($i + $startIdx)) { + | $values.setNullAt($i); + | } else { + | $values.set$primitiveValueTypeName($i, $getValue); + | } + |} + |${ev.value} = $values; + """.stripMargin + } + } +} + +/** + * Creates a String containing all the elements of the input array separated by the delimiter. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array + using the delimiter and an optional string to replace nulls. If no value is set for + nullReplacement, any null value is filtered.""", + examples = """ + Examples: + > SELECT _FUNC_(array('hello', 'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' ', ','); + hello , world + """, since = "2.4.0") +case class ArrayJoin( + array: Expression, + delimiter: Expression, + nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes { + + def this(array: Expression, delimiter: Expression) = this(array, delimiter, None) + + def this(array: Expression, delimiter: Expression, nullReplacement: Expression) = + this(array, delimiter, Some(nullReplacement)) + + override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { + Seq(ArrayType(StringType), StringType, StringType) + } else { + Seq(ArrayType(StringType), StringType) + } + + override def children: Seq[Expression] = if (nullReplacement.isDefined) { + Seq(array, delimiter, nullReplacement.get) + } else { + Seq(array, delimiter) + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = { + val arrayEval = array.eval(input) + if (arrayEval == null) return null + val delimiterEval = delimiter.eval(input) + if (delimiterEval == null) return null + val nullReplacementEval = nullReplacement.map(_.eval(input)) + if (nullReplacementEval.contains(null)) return null + + val buffer = new UTF8StringBuilder() + var firstItem = true + val nullHandling = nullReplacementEval match { + case Some(rep) => (prependDelimiter: Boolean) => { + if (!prependDelimiter) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(rep.asInstanceOf[UTF8String]) + true + } + case None => (_: Boolean) => false + } + arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => { + if (item == null) { + if (nullHandling(firstItem)) { + firstItem = false + } + } else { + if (!firstItem) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(item.asInstanceOf[UTF8String]) + firstItem = false + } + }) + buffer.build() + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val code = nullReplacement match { + case Some(replacement) => + val replacementGen = replacement.genCode(ctx) + val nullHandling = (buffer: String, delimiter: String, firstItem: String) => { + s""" + |if (!$firstItem) { + | $buffer.append($delimiter); + |} + |$buffer.append(${replacementGen.value}); + |$firstItem = false; + """.stripMargin + } + val execCode = if 
(replacement.nullable) { + ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + } else { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + s""" + |${replacementGen.code} + |$execCode + """.stripMargin + case None => genCodeForArrayAndDelimiter(ctx, ev, + (_: String, _: String, _: String) => "// nulls are ignored") + } + if (nullable) { + ev.copy( + s""" + |boolean ${ev.isNull} = true; + |UTF8String ${ev.value} = null; + |$code + """.stripMargin) + } else { + ev.copy( + s""" + |UTF8String ${ev.value} = null; + |$code + """.stripMargin, FalseLiteral) + } + } + + private def genCodeForArrayAndDelimiter( + ctx: CodegenContext, + ev: ExprCode, + nullEval: (String, String, String) => String): String = { + val arrayGen = array.genCode(ctx) + val delimiterGen = delimiter.genCode(ctx) + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val i = ctx.freshName("i") + val firstItem = ctx.freshName("firstItem") + val resultCode = + s""" + |$bufferClass $buffer = new $bufferClass(); + |boolean $firstItem = true; + |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) { + | if (${arrayGen.value}.isNullAt($i)) { + | ${nullEval(buffer, delimiterGen.value, firstItem)} + | } else { + | if (!$firstItem) { + | $buffer.append(${delimiterGen.value}); + | } + | $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)}); + | $firstItem = false; + | } + |} + |${ev.value} = $buffer.build();""".stripMargin + + if (array.nullable || delimiter.nullable) { + arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) { + delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode""".stripMargin + } + } + } else { + s""" + |${arrayGen.code} + |${delimiterGen.code} + |$resultCode""".stripMargin + } + } + + override def dataType: DataType = StringType + +} + +/** + * Returns the minimum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the minimum value in the array. 
NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 1 + """, since = "2.4.0") +case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfSmaller(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var min: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (min == null || ordering.lt(item, min))) { + min = item + } + ) + min + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_min" +} + +/** + * Returns the maximum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the maximum value in the array. 
NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 20 + """, since = "2.4.0") +case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfGreater(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var max: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (max == null || ordering.gt(item, max))) { + max = item + } + ) + max + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_max" +} + + +/** + * Returns the position of the first occurrence of element in the given array as long. + * Returns 0 if the given value could not be found in the array. Returns null if either of + * the arguments are null + * + * NOTE: that this is not zero based, but 1-based index. The first element in the array has + * index 1. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(3, 2, 1), 1); + 3 + """, + since = "2.4.0") +case class ArrayPosition(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = LongType + override def inputTypes: Seq[AbstractDataType] = + Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + + override def nullSafeEval(arr: Any, value: Any): Any = { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == value) { + return (i + 1).toLong + } + ) + 0L + } + + override def prettyName: String = "array_position" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val pos = ctx.freshName("arrayPosition") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, right.dataType, i) + s""" + |int $pos = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { + | $pos = $i + 1; + | break; + | } + |} + |${ev.value} = (long) $pos; + """.stripMargin + }) + } +} + +/** + * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, + accesses elements from the last to the first. Returns NULL if the index exceeds the length + of the array. + + _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + 2 + > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); + "b" + """, + since = "2.4.0") +case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + + override def dataType: DataType = left.dataType match { + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + Seq(TypeCollection(ArrayType, MapType), + left.dataType match { + case _: ArrayType => IntegerType + case _: MapType => left.dataType.asInstanceOf[MapType].keyType + } + ) + } + + override def nullable: Boolean = true + + override def nullSafeEval(value: Any, ordinal: Any): Any = { + left.dataType match { + case _: ArrayType => + val array = value.asInstanceOf[ArrayData] + val index = ordinal.asInstanceOf[Int] + if (array.numElements() < math.abs(index)) { + null + } else { + val idx = if (index == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else if (index > 0) { + index - 1 + } else { + array.numElements() + index + } + if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + null + } else { + array.get(idx, dataType) + } + } + case _: MapType => + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + left.dataType match { + case _: ArrayType => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else + """.stripMargin + } else { + "" + } + s""" + |int $index = (int) $eval2; + |if ($eval1.numElements() < Math.abs($index)) { + | ${ev.isNull} = true; + |} else { + | if ($index == 0) { + | throw new 
ArrayIndexOutOfBoundsException("SQL array indices start at 1"); + | } else if ($index > 0) { + | $index--; + | } else { + | $index += $eval1.numElements(); + | } + | $nullCheck + | { + | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + | } + |} + """.stripMargin + }) + case _: MapType => + doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) + } + } + + override def prettyName: String = "element_at" +} + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ + Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been ${StringType.simpleString}," + + s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " + + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") + } + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { + case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) + case StringType => + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) + case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { + null + } else { + val arrayData = inputs.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + + s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val finalData = new Array[AnyRef](numberOfElements.toInt) + var position = 0 + for(ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, finalData, position, arr.length) + position += arr.length + } + new GenericArrayData(finalData) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val args = ctx.freshName("args") + + val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } + + val (concatenator, initCode) = dataType match { + case BinaryType => + (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => + 
("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => + val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrays(ctx, elementType) + } else { + genCodeForNonPrimitiveArrays(ctx, elementType) + } + (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") + } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"$javaType[]", args) :: Nil) + ev.copy(s""" + $initCode + $codes + $javaType ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; + """) + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { + val numElements = ctx.freshName("numElements") + val code = s""" + |long $numElements = 0L; + |for (int z = 0; z < ${children.length}; z++) { + | $numElements += args[z].numElements(); + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + + (code, numElements) + } + + private def nullArgumentProtection() : String = { + if (nullable) { + s""" + |for (int z = 0; z < ${children.length}; z++) { + | if (args[z] == null) return null; + |} + """.stripMargin + } else { + "" + } + } + + private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | if (args[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + | } else { + | $arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} + | ); + | } + | $counter++; + | } + | } + | return $arrayData; + | } + |}""".stripMargin.stripPrefix("\n") + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | Object[] $arrayData = new Object[(int)$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + | } + |}""".stripMargin.stripPrefix("\n") + } + + override def toString: String = s"concat(${children.mkString(", ")})" + + override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" +} + +/** + * Transforms an array of arrays into a single array. 
+ */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", + examples = """ + Examples: + > SELECT _FUNC_(array(array(1, 2), array(3, 4)); + [1,2,3,4] + """, + since = "2.4.0") +case class Flatten(child: Expression) extends UnaryExpression { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def nullable: Boolean = child.nullable || childDataType.containsNull + + override def dataType: DataType = childDataType.elementType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(_: ArrayType, _) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + s"The argument should be an array of arrays, " + + s"but '${child.sql}' is of ${child.dataType.simpleString} type." + ) + } + + override def nullSafeEval(child: Any): Any = { + val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType) + + if (elements.contains(null)) { + null + } else { + val arrayData = elements.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val flattenedData = new Array(numberOfElements.toInt) + var position = 0 + for (ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, flattenedData, position, arr.length) + position += arr.length + } + new GenericArrayData(flattenedData) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val code = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) + } + if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code + }) + } + + private def nullElementsProtection( + ev: ExprCode, + childVariableName: String, + coreLogic: String): String = { + s""" + |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { + | ${ev.isNull} |= $childVariableName.isNullAt(z); + |} + |if (!${ev.isNull}) { + | $coreLogic + |} + """.stripMargin + } + + private def genCodeForNumberOfElements( + ctx: CodegenContext, + childVariableName: String) : (String, String) = { + val variableName = ctx.freshName("numElements") + val code = s""" + |long $variableName = 0; + |for (int z = 0; z < $childVariableName.numElements(); z++) { + | $variableName += $childVariableName.getArray(z).numElements(); + |} + |if ($variableName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + (code, variableName) + } + + private def genCodeForFlattenOfPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val counter = ctx.freshName("counter") + val tempArrayDataName = ctx.freshName("tempArrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(elementType) + + s""" + |$numElemCode + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")} + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | if (arr.isNullAt(l)) { + | $tempArrayDataName.setNullAt($counter); + | } else { + | $tempArrayDataName.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue("arr", elementType, "l")} + | ); + | } + | $counter++; + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForFlattenOfNonPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val counter = ctx.freshName("counter") + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; + | $counter++; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + + override def prettyName: String = "flatten" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 3dc2ee03a86e..67876a856548 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,8 +64,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = arrayData, - isNull = "false") + value = JavaCode.variable(arrayData, dataType), + isNull = FalseLiteral) } override def prettyName: String = "array" @@ -90,7 +90,7 @@ private [sql] object GenArrayData { val arrayDataName = ctx.freshName("arrayData") val numElements = elementsCode.length - if (!ctx.isPrimitiveType(elementType)) { + if (!CodeGenerator.isPrimitiveType(elementType)) { val arrayName = ctx.freshName("arrayObject") val genericArrayClass = classOf[GenericArrayData].getName @@ -111,7 +111,7 @@ private [sql] object GenArrayData { val assignmentString = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "apply", - extraArguments = ("Object[]", arrayDataName) :: Nil) + extraArguments = ("Object[]", arrayName) :: Nil) (s"Object[] $arrayName = new Object[$numElements];", assignmentString, @@ -124,7 +124,7 @@ private [sql] object GenArrayData { ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = ctx.primitiveTypeName(elementType) + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { s"$arrayDataName.setNullAt($i);" @@ -378,7 +378,7 @@ case class 
CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); |$values = null; - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } override def prettyName: String = "named_struct" @@ -394,7 +394,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) - ExprCode(code = eval.code, isNull = "false", value = eval.value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value) } override def prettyName: String = "named_struct_unsafe" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 7e53ca390890..3fba52d74545 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -129,12 +129,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; } """ } else { s""" - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; """ } }) @@ -205,7 +205,7 @@ case class GetArrayStructFields( } else { final InternalRow $row = $eval.getStruct($j, $numFields); $nullSafeEval { - $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; + $values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)}; } } } @@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) if ($index >= $eval1.numElements() || $index < 0$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; } """ }) @@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `key` in Map `child`. - * - * We need to do type checking here as `key` expression maybe unresolved. + * Common base class for [[GetMapValue]] and [[ElementAt]]. */ -case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant { - - private def keyType = child.dataType.asInstanceOf[MapType].keyType - - // We have done type checking for child in `ExtractValue`, so only need to check the `key`. 
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) - - override def toString: String = s"$child[$key]" - override def sql: String = s"${child.sql}[${key.sql}]" - - override def left: Expression = child - override def right: Expression = key - - /** `Null` is returned for invalid ordinals. */ - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType +abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - protected override def nullSafeEval(value: Any, ordinal: Any): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -315,18 +296,20 @@ case class GetMapValue(child: Expression, key: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") val values = ctx.freshName("values") - val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) { + val keyType = mapType.keyType + val nullCheck = if (mapType.valueContainsNull) { s" || $values.isNullAt($index)" } else { "" } + val keyJavaType = CodeGenerator.javaType(keyType) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); @@ -336,7 +319,7 @@ case class GetMapValue(child: Expression, key: Expression) int $index = 0; boolean $found = false; while ($index < $length && !$found) { - final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)}; + final $keyJavaType $key = ${CodeGenerator.getValue(keys, keyType, index)}; if (${ctx.genEqual(keyType, key, eval2)}) { $found = true; } else { @@ -347,9 +330,43 @@ case class GetMapValue(child: Expression, key: Expression) if (!$found$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(values, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(values, dataType, index)}; } """ }) } } + +/** + * Returns the value of key `key` in Map `child`. + * + * We need to do type checking here as `key` expression maybe unresolved. + */ +case class GetMapValue(child: Expression, key: Expression) + extends GetMapValueUtil with ExtractValue with NullIntolerant { + + private def keyType = child.dataType.asInstanceOf[MapType].keyType + + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) + + override def toString: String = s"$child[$key]" + override def sql: String = s"${child.sql}[${key.sql}]" + + override def left: Expression = child + override def right: Expression = key + + /** `Null` is returned for invalid ordinals. */ + override def nullable: Boolean = true + + override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType + + // todo: current search is O(n), improve it. 
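// Editor's illustration (not part of the patch): the linear key scan flagged by the
// "todo: current search is O(n)" note above, as now shared by GetMapValue and
// element_at(map, key) through GetMapValueUtil.getValueEval. Names below are hypothetical.
object MapValueLookupSketch {
  def getValue[K, V](keys: IndexedSeq[K], values: IndexedSeq[Option[V]], key: K): Option[V] = {
    var index = 0
    var found = false
    while (index < keys.length && !found) {
      if (keys(index) == key) found = true else index += 1
    }
    if (!found) None else values(index)             // None also stands in for a stored NULL value
  }

  def main(args: Array[String]): Unit = {
    // mirrors: SELECT element_at(map(1, 'a', 2, 'b'), 2)
    assert(getValue(IndexedSeq(1, 2), IndexedSeq(Some("a"), Some("b")), 2) == Some("b"))
    assert(getValue(IndexedSeq(1, 2), IndexedSeq(Some("a"), Some("b")), 3).isEmpty)
  }
}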
+ override def nullSafeEval(value: Any, ordinal: Any): Any = { + getValueEval(value, ordinal, keyType) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType]) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index b444c3a7be92..205d77f6a9ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -69,7 +69,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s""" |${condEval.code} |boolean ${ev.isNull} = false; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${condEval.isNull} && ${condEval.value}) { | ${trueEval.code} | ${ev.isNull} = ${trueEval.isNull}; @@ -191,7 +191,9 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value) + ev.value = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), + dataType) // these blocks are meant to be inside a // do { @@ -244,10 +246,10 @@ case class CaseWhen( val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, funcName = "caseWhen", - returnType = ctx.JAVA_BYTE, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = func => s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $func |} while (false); @@ -264,7 +266,7 @@ case class CaseWhen( ev.copy(code = s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes |} while (false); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 7a674ea7f4d7..76aa61415a11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -23,8 +23,10 @@ import java.util.{Calendar, TimeZone} import scala.util.control.NonFatal +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -424,36 +426,71 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa """, since = "2.3.0") // scalastyle:on line.size.limit -case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) +case class DayOfWeek(child: Expression) extends DayWeek { - override def 
dataType: DataType = IntegerType + override protected def nullSafeEval(date: Any): Any = { + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + cal.get(Calendar.DAY_OF_WEEK) + } - @transient private lazy val c = { - Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, time => { + val cal = classOf[Calendar].getName + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val c = "calDayOfWeek" + ctx.addImmutableStateIfNotExists(cal, c, + v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") + s""" + $c.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.value} = $c.get($cal.DAY_OF_WEEK); + """ + }) } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date) - Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday).", + examples = """ + Examples: + > SELECT _FUNC_('2009-07-30'); + 3 + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class WeekDay(child: Expression) extends DayWeek { override protected def nullSafeEval(date: Any): Any = { - c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) - c.get(Calendar.DAY_OF_WEEK) + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + (cal.get(Calendar.DAY_OF_WEEK) + 5 ) % 7 } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = "calDayOfWeek" + val c = "calWeekDay" ctx.addImmutableStateIfNotExists(cal, c, v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") s""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.DAY_OF_WEEK); + ${ev.value} = ($c.get($cal.DAY_OF_WEEK) + 5) % 7; """ }) } } +abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + @transient protected lazy val cal: Calendar = { + Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + } +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(date) - Returns the week of the year of the given date. 
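// Editor's illustration (not part of the patch): how the new WeekDay expression remaps
// java.util.Calendar's DAY_OF_WEEK (1 = Sunday ... 7 = Saturday) onto the documented
// 0 = Monday ... 6 = Sunday numbering. WeekDaySketch is a hypothetical name.
import java.util.{Calendar, TimeZone}

object WeekDaySketch {
  def weekDay(epochDays: Int): Int = {
    val cal = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
    cal.setTimeInMillis(epochDays * 1000L * 3600L * 24L)
    (cal.get(Calendar.DAY_OF_WEEK) + 5) % 7
  }

  def main(args: Array[String]): Unit = {
    // 2009-07-30 is a Thursday: DAY_OF_WEEK = 5, so (5 + 5) % 7 = 3, matching the doc example
    val thursday = java.time.LocalDate.of(2009, 7, 30).toEpochDay.toInt
    assert(weekDay(thursday) == 3)
  }
}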
A week is considered to start on a Monday and week 1 is the first week with >3 days.", @@ -671,18 +708,19 @@ abstract class UnixTime } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = CodeGenerator.javaType(dataType) left.dataType match { case StringType if right.foldable => val df = classOf[DateFormat].getName if (formatter == null) { - ExprCode("", "true", ctx.defaultValue(dataType)) + ExprCode.forNullValue(dataType) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L; @@ -711,7 +749,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = ${eval1.value} / 1000000L; }""") @@ -722,7 +760,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L; }""") @@ -810,14 +848,14 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", "true", "(UTF8String) null") + ExprCode.forNullValue(StringType) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = UTF8String.fromString($formatterName.format( @@ -978,6 +1016,48 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S } } +/** + * A special expression used to convert the string input of `to/from_utc_timestamp` to timestamp, + * which requires the timestamp string to not have timezone information, otherwise null is returned. 
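// Editor's illustration (not part of the patch): the intended behaviour of the new
// StringToTimestampWithoutTimezone helper described below, approximated with java.time.
// The zone-detection check is only a rough stand-in for the stricter rejectTzInString
// handling inside DateTimeUtils.stringToTimestamp; all names here are hypothetical.
import java.time.{LocalDateTime, ZoneId}
import java.time.format.DateTimeFormatter
import scala.util.Try

object StringToTimestampSketch {
  private val fmt = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")

  def parseWithoutTz(s: String, sessionZone: ZoneId): Option[Long] =
    if (s.contains("Z") || s.contains("+") || s.count(_ == '-') > 2) {
      None // the string carries its own zone/offset, so the expression yields NULL
    } else {
      Try(LocalDateTime.parse(s, fmt)).toOption
        .map(_.atZone(sessionZone).toInstant.toEpochMilli * 1000L) // microseconds, like TimestampType
    }

  def main(args: Array[String]): Unit = {
    val utc = ZoneId.of("UTC")
    assert(parseWithoutTz("2017-07-14 02:40:00", utc).isDefined)
    assert(parseWithoutTz("2017-07-14 02:40:00+01:00", utc).isEmpty)
  }
}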
+ */ +case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def dataType: DataType = TimestampType + override def nullable: Boolean = true + override def toString: String = child.toString + override def sql: String = child.sql + + override def nullSafeEval(input: Any): Any = { + DateTimeUtils.stringToTimestamp( + input.asInstanceOf[UTF8String], timeZone, rejectTzInString = true).orNull + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = ctx.addReferenceObj("timeZone", timeZone) + val longOpt = ctx.freshName("longOpt") + val eval = child.genCode(ctx) + val code = s""" + |${eval.code} + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; + |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; + |if (!${eval.isNull}) { + | scala.Option $longOpt = $dtu.stringToTimestamp(${eval.value}, $tz, true); + | if ($longOpt.isDefined()) { + | ${ev.value} = ((Long) $longOpt.get()).longValue(); + | ${ev.isNull} = false; + | } + |} + """.stripMargin + ev.copy(code = code) + } +} + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield @@ -1008,7 +1088,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { - val tz = right.eval() + val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { ev.copy(code = s""" |boolean ${ev.isNull} = true; @@ -1017,8 +1097,9 @@ case class FromUTCTimestamp(left: Expression, right: Expression) } else { val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$tz");""") + v => s"""$v = $dtu.getTimeZone("$escapedTz");""") val utcTerm = "tzUTC" ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") @@ -1117,38 +1198,49 @@ case class AddMonths(startDate: Expression, numMonths: Expression) */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(timestamp1, timestamp2) - Returns number of months between `timestamp1` and `timestamp2`.", + usage = """ + _FUNC_(timestamp1, timestamp2[, roundOff]) - Returns number of months between `timestamp1` and `timestamp2`. + The result is rounded to 8 decimal places by default. 
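// Editor's illustration (not part of the patch): the effect of the new roundOff flag on
// months_between. Only the final rounding to 8 decimal places is controlled; the exact
// rounding mode lives in DateTimeUtils.monthsBetween and may differ from the stand-in below.
object MonthsBetweenRoundingSketch {
  def round8(d: Double, roundOff: Boolean): Double =
    if (roundOff) BigDecimal(d).setScale(8, BigDecimal.RoundingMode.HALF_UP).toDouble else d

  def main(args: Array[String]): Unit = {
    val raw = 3.9495967741935485 // months_between('1997-02-28 10:30:00', '1996-10-30', false)
    assert(round8(raw, roundOff = true) == 3.94959677) // default behaviour, as in the first doc example
    assert(round8(raw, roundOff = false) == raw)       // new opt-out returns the unrounded value
  }
}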
Set roundOff=false otherwise."""", examples = """ Examples: > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30'); 3.94959677 + > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30', false); + 3.9495967741935485 """, since = "1.5.0") // scalastyle:on line.size.limit -case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { +case class MonthsBetween( + date1: Expression, + date2: Expression, + roundOff: Expression, + timeZoneId: Option[String] = None) + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) - def this(date1: Expression, date2: Expression) = this(date1, date2, None) + def this(date1: Expression, date2: Expression, roundOff: Expression) = + this(date1, date2, roundOff, None) - override def left: Expression = date1 - override def right: Expression = date2 + override def children: Seq[Expression] = Seq(date1, date2, roundOff) - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType, BooleanType) override def dataType: DataType = DoubleType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(t1: Any, t2: Any): Any = { - DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long], timeZone) + override def nullSafeEval(t1: Any, t2: Any, roundOff: Any): Any = { + DateTimeUtils.monthsBetween( + t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (l, r) => { - s"""$dtu.monthsBetween($l, $r, $tz)""" + defineCodeGen(ctx, ev, (d1, d2, roundOff) => { + s"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)""" }) } @@ -1185,7 +1277,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { - val tz = right.eval() + val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { ev.copy(code = s""" |boolean ${ev.isNull} = true; @@ -1194,8 +1286,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) } else { val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$tz");""") + v => s"""$v = $dtu.getTimeZone("$escapedTz");""") val utcTerm = "tzUTC" ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") @@ -1340,18 +1433,19 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { : ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } 
else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.$truncFuncStr; }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 4f4d49166e88..3af4bfebad45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -218,7 +218,7 @@ case class Stack(children: Seq[Expression]) extends Generator { s""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 055ebf6c0da5..ef790338bdd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -269,7 +270,7 @@ abstract class HashExpression[E] extends Expression { protected def computeHash(value: Any, dataType: DataType, seed: E): E override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = FalseLiteral val childrenHash = children.map { child => val childGen = child.genCode(ctx) @@ -278,7 +279,7 @@ abstract class HashExpression[E] extends Expression { } } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", @@ -307,9 +308,10 @@ abstract class HashExpression[E] extends Expression { ctx: CodegenContext): String = { val element = ctx.freshName("element") + val jt = CodeGenerator.javaType(elementType) ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { s""" - final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + final $jt $element = ${CodeGenerator.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} """ } @@ -359,10 +361,7 @@ abstract class HashExpression[E] extends Expression { } protected def genHashString(input: String, 
result: String): String = { - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" + s"$result = $hasherClassName.hashUTF8String($input, $result);" } protected def genHashForMap( @@ -407,7 +406,7 @@ abstract class HashExpression[E] extends Expression { val fieldsHash = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", @@ -464,6 +463,8 @@ abstract class InterpretedHashFunction { protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long + protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long + /** * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity * of input `value`. @@ -489,8 +490,7 @@ abstract class InterpretedHashFunction { case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed)) case a: Array[Byte] => hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) - case s: UTF8String => - hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) + case s: UTF8String => hashUnsafeBytesBlock(s.getMemoryBlock(), seed) case array: ArrayData => val elementType = dataType match { @@ -577,9 +577,15 @@ object Murmur3HashFunction extends InterpretedHashFunction { Murmur3_x86_32.hashLong(l, seed.toInt) } - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt) } + + override protected def hashUnsafeBytesBlock( + base: MemoryBlock, seed: Long): Long = { + Murmur3_x86_32.hashUnsafeBytesBlock(base, seed.toInt) + } } /** @@ -604,9 +610,14 @@ object XxHash64Function extends InterpretedHashFunction { override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed) - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { XXH64.hashUnsafeBytes(base, offset, len, seed) } + + override protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long = { + XXH64.hashUnsafeBytesBlock(base, seed) + } } /** @@ -632,7 +643,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = FalseLiteral val childHash = ctx.freshName("childHash") val childrenHash = children.map { child => @@ -651,11 +662,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", - extraArguments = Seq(ctx.JAVA_INT -> ev.value), - returnType = ctx.JAVA_INT, + extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$body |return ${ev.value}; """.stripMargin, @@ -664,8 +675,8 @@ case class 
HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = s""" - |${ctx.JAVA_INT} ${ev.value} = $seed; - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes """.stripMargin) } @@ -713,10 +724,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """ override protected def genHashString(input: String, result: String): String = { - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);" + s"$result = $hasherClassName.hashUTF8String($input);" } override protected def genHashForArray( @@ -780,14 +788,14 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result), - returnType = ctx.JAVA_INT, + arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childResult = 0; + |${CodeGenerator.JAVA_INT} $childResult = 0; |$body |return $result; """.stripMargin, @@ -804,10 +812,14 @@ object HiveHashFunction extends InterpretedHashFunction { HiveHasher.hashLong(l) } - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { HiveHasher.hashUnsafeBytes(base, offset, len) } + override protected def hashUnsafeBytesBlock( + base: MemoryBlock, seed: Long): Long = HiveHasher.hashUnsafeBytesBlock(base) + private val HIVE_DECIMAL_MAX_PRECISION = 38 private val HIVE_DECIMAL_MAX_SCALE = 38 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 7a8edabed175..2a3cc580273e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,8 +42,8 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = "false") + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + + s"$className.getInputFilePath();", isNull = FalseLiteral) } } @@ -65,8 +65,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override 
def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = "false") + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + + s"$className.getStartOffset();", isNull = FalseLiteral) } } @@ -88,7 +88,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = "false") + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + + s"$className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 18b4fed59744..34161f0f03f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -513,12 +514,16 @@ case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, - timeZoneId: Option[String] = None) + timeZoneId: Option[String], + forceNullableSchema: Boolean) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { - override def nullable: Boolean = true - def this(schema: DataType, options: Map[String, String], child: Expression) = - this(schema, options, child, None) + // The JSON input data might be missing certain fields. We force the nullability + // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder + // can generate incorrect files if values are missing in columns declared as non-nullable. 
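// Editor's illustration (not part of the patch): what forcing the user-provided from_json
// schema to nullable means for a top-level struct, using only the public StructType API.
// The patch itself applies schema.asNullable, which also covers nested fields, array
// elements and map values; ForceNullableSchemaSketch is a hypothetical name.
import org.apache.spark.sql.types._

object ForceNullableSchemaSketch {
  def forceNullableTopLevel(schema: StructType): StructType =
    StructType(schema.fields.map(_.copy(nullable = true)))

  def main(args: Array[String]): Unit = {
    val userSchema = StructType(Seq(StructField("id", LongType, nullable = false)))
    val parsed = forceNullableTopLevel(userSchema)
    assert(parsed.fields.forall(_.nullable)) // fields missing from the JSON can now surface as NULL safely
  }
}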
+ val nullableSchema = if (forceNullableSchema) schema.asNullable else schema + + override def nullable: Boolean = true // Used in `FunctionRegistry` def this(child: Expression, schema: Expression) = @@ -526,31 +531,38 @@ case class JsonToStructs( schema = JsonExprUtils.validateSchemaLiteral(schema), options = Map.empty[String, String], child = child, - timeZoneId = None) + timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) def this(child: Expression, schema: Expression, options: Expression) = this( schema = JsonExprUtils.validateSchemaLiteral(schema), options = JsonExprUtils.convertToMapData(options), child = child, - timeZoneId = None) + timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + + // Used in `org.apache.spark.sql.functions` + def this(schema: DataType, options: Map[String, String], child: Expression) = + this(schema, options, child, timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) - override def checkInputDataTypes(): TypeCheckResult = schema match { + override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( - s"Input schema ${schema.simpleString} must be a struct or an array of structs.") + s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") } @transient - lazy val rowSchema = schema match { + lazy val rowSchema = nullableSchema match { case st: StructType => st case ArrayType(st: StructType, _) => st } // This converts parsed rows to the desired output by the given schema. @transient - lazy val converter = schema match { + lazy val converter = nullableSchema match { case _: StructType => (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null case ArrayType(_: StructType, _) => @@ -563,7 +575,7 @@ case class JsonToStructs( rowSchema, new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) - override def dataType: DataType = schema + override def dataType: DataType = nullableSchema override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 383203a20983..246025b82d59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -58,7 +58,7 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) + case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) @@ -277,41 +277,45 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def eval(input: InternalRow): Any = value override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) - // change the isNull and 
primitive to consts, to inline them + val javaType = CodeGenerator.javaType(dataType) if (value == null) { - ev.isNull = "true" - ev.copy(s"final $javaType ${ev.value} = ${ctx.defaultValue(dataType)};") + ExprCode.forNullValue(dataType) } else { - ev.isNull = "false" + def toExprCode(code: String): ExprCode = { + ExprCode.forNonNullValue(JavaCode.literal(code, dataType)) + } dataType match { case BooleanType | IntegerType | DateType => - ev.copy(code = "", value = value.toString) + toExprCode(value.toString) case FloatType => - val v = value.asInstanceOf[Float] - if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceObj("boxedValue", v) - val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" - ev.copy(code = code) - } else { - ev.copy(code = "", value = s"${value}f") + value.asInstanceOf[Float] match { + case v if v.isNaN => + toExprCode("Float.NaN") + case Float.PositiveInfinity => + toExprCode("Float.POSITIVE_INFINITY") + case Float.NegativeInfinity => + toExprCode("Float.NEGATIVE_INFINITY") + case _ => + toExprCode(s"${value}F") } case DoubleType => - val v = value.asInstanceOf[Double] - if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceObj("boxedValue", v) - val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" - ev.copy(code = code) - } else { - ev.copy(code = "", value = s"${value}D") + value.asInstanceOf[Double] match { + case v if v.isNaN => + toExprCode("Double.NaN") + case Double.PositiveInfinity => + toExprCode("Double.POSITIVE_INFINITY") + case Double.NegativeInfinity => + toExprCode("Double.NEGATIVE_INFINITY") + case _ => + toExprCode(s"${value}D") } case ByteType | ShortType => - ev.copy(code = "", value = s"($javaType)$value") + ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType)) case TimestampType | LongType => - ev.copy(code = "", value = s"${value}L") + toExprCode(s"${value}L") case _ => - ev.copy(code = "", value = ctx.addReferenceObj("literal", value, - ctx.javaType(dataType))) + val constRef = ctx.addReferenceObj("literal", value, javaType) + ExprCode.forNonNullValue(JavaCode.global(constRef, dataType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index d8dc0862f114..bc4cfcec4742 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -168,9 +168,11 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse cosine (a.k.a. arccosine) of `expr` if -1<=`expr`<=1 or NaN otherwise.", + usage = """ + _FUNC_(expr) - Returns the inverse cosine (a.k.a. arc cosine) of `expr`, as if computed by + `java.lang.Math._FUNC_`. 
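// Editor's illustration (not part of the patch): the literal-to-Java-source mapping that
// the rewritten Literal.doGenCode float branch (a few hunks above) produces, emitting
// named constants for NaN and infinities instead of boxed reference objects. The helper
// name is hypothetical.
object FloatingPointLiteralCodegenSketch {
  def floatLiteralCode(v: Float): String = v match {
    case f if f.isNaN           => "Float.NaN"
    case Float.PositiveInfinity => "Float.POSITIVE_INFINITY"
    case Float.NegativeInfinity => "Float.NEGATIVE_INFINITY"
    case f                      => s"${f}F"
  }

  def main(args: Array[String]): Unit = {
    assert(floatLiteralCode(Float.NaN) == "Float.NaN")
    assert(floatLiteralCode(1.5f) == "1.5F")
  }
}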
+ """, examples = """ Examples: > SELECT _FUNC_(1); @@ -178,12 +180,13 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") > SELECT _FUNC_(2); NaN """) -// scalastyle:on line.size.limit case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse sine (a.k.a. arcsine) the arc sin of `expr` if -1<=`expr`<=1 or NaN otherwise.", + usage = """ + _FUNC_(expr) - Returns the inverse sine (a.k.a. arc sine) the arc sin of `expr`, + as if computed by `java.lang.Math._FUNC_`. + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -191,18 +194,18 @@ case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS" > SELECT _FUNC_(2); NaN """) -// scalastyle:on line.size.limit case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse tangent (a.k.a. arctangent).", + usage = """ + _FUNC_(expr) - Returns the inverse tangent (a.k.a. arc tangent) of `expr`, as if computed by + `java.lang.Math._FUNC_` + """, examples = """ Examples: > SELECT _FUNC_(0); 0.0 """) -// scalastyle:on line.size.limit case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") @ExpressionDescription( @@ -252,7 +255,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the cosine of `expr`.", + usage = """ + _FUNC_(expr) - Returns the cosine of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -261,7 +271,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic cosine of `expr`.", + usage = """ + _FUNC_(expr) - Returns the hyperbolic cosine of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -512,7 +529,11 @@ case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the sine of `expr`.", + usage = "_FUNC_(expr) - Returns the sine of `expr`, as if computed by `java.lang.Math._FUNC_`.", + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -521,7 +542,13 @@ case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "S case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic sine of `expr`.", + usage = """ + _FUNC_(expr) - Returns hyperbolic sine of `expr`, as if computed by `java.lang.Math._FUNC_`. 
+ """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -539,7 +566,13 @@ case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH" case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the tangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the tangent of `expr`, as if computed by `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -548,7 +581,13 @@ case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT" case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the cotangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the cotangent of `expr`, as if computed by `1/java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(1); @@ -562,7 +601,14 @@ case class Cot(child: Expression) } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic tangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the hyperbolic tangent of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -572,6 +618,10 @@ case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH" @ExpressionDescription( usage = "_FUNC_(expr) - Converts radians to degrees.", + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(3.141592653589793); @@ -583,6 +633,10 @@ case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegre @ExpressionDescription( usage = "_FUNC_(expr) - Converts degrees to radians.", + arguments = """ + Arguments: + * expr - angle in degrees + """, examples = """ Examples: > SELECT _FUNC_(180); @@ -768,15 +822,22 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns the angle in radians between the positive x-axis of a plane and the point given by the coordinates (`expr1`, `expr2`).", + usage = """ + _FUNC_(exprY, exprX) - Returns the angle in radians between the positive x-axis of a plane + and the point given by the coordinates (`exprX`, `exprY`), as if computed by + `java.lang.Math._FUNC_`. 
+ """, + arguments = """ + Arguments: + * exprY - coordinate on y-axis + * exprX - coordinate on x-axis + """, examples = """ Examples: > SELECT _FUNC_(0, 0); 0.0 """) -// scalastyle:on line.size.limit case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { @@ -1128,15 +1189,16 @@ abstract class RoundBase(child: Expression, scale: Expression, }""" } + val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { ev.copy(code = s""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $evaluationCode }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4b9006ab5b42..7eda65a86702 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -31,7 +32,12 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType - protected override def nullSafeEval(input: Any): Any = input + protected override def nullSafeEval(input: Any): Any = { + // scalastyle:off println + System.err.println(outputPrefix + input) + // scalastyle:on println + input + } private val outputPrefix = s"Result of ${child.simpleString} is " @@ -85,7 +91,8 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); - |}""".stripMargin, isNull = "true", value = "null") + |}""".stripMargin, isNull = TrueLiteral, + value = JavaCode.defaultLiteral(dataType)) } override def sql: String = s"assert_true(${child.sql})" @@ -117,18 +124,35 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { 46707d92-02f4-4817-8116-a4c3b23e6266 """) // scalastyle:on line.size.limit -case class Uuid() extends LeafExpression { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { + + def this() = this(None) - override lazy val deterministic: Boolean = false + override lazy val resolved: Boolean = randomSeed.isDefined override def nullable: Boolean = false override def dataType: DataType = StringType - override def eval(input: InternalRow): Any = UTF8String.fromString(UUID.randomUUID().toString) + @transient private[this] var randomGenerator: RandomUUIDGenerator = _ + + override protected def initializeInternal(partitionIndex: Int): Unit = + randomGenerator = RandomUUIDGenerator(randomSeed.get + partitionIndex) + + override protected def evalInternal(input: InternalRow): Any = + randomGenerator.getNextUUIDUTF8String() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - 
ev.copy(code = s"final UTF8String ${ev.value} = " + - s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false") + val randomGen = ctx.freshName("randomGen") + ctx.addMutableState("org.apache.spark.sql.catalyst.util.RandomUUIDGenerator", randomGen, + forceInline = true, + useFreshName = false) + ctx.addPartitionInitializationStatement(s"$randomGen = " + + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + + s"${randomSeed.get}L + partitionIndex);") + ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + isNull = FalseLiteral) } + + override def freshCopy(): Uuid = Uuid(randomSeed) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 470d5da041ea..0787342bce6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) // all the evals are meant to be in a do { ... 
} while (false); loop val evals = children.map { e => @@ -87,14 +87,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """.stripMargin } - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", returnType = resultType, makeSplitFunction = func => s""" - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $func |} while (false); @@ -113,7 +113,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $codes |} while (false); @@ -234,8 +234,8 @@ case class IsNaN(child: Expression) extends UnaryExpression case DoubleType | FloatType => ev.copy(code = s""" ${eval.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false") + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) } } } @@ -281,7 +281,7 @@ case class NaNvl(left: Expression, right: Expression) ev.copy(code = s""" ${leftGen.code} boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${leftGen.isNull}) { ${ev.isNull} = true; } else { @@ -320,7 +320,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = eval.isNull) + ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.isNull) } override def sql: String = s"(${child.sql} IS NULL)" @@ -346,7 +346,12 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = s"(!(${eval.isNull}))") + val value = eval.isNull match { + case TrueLiteral => FalseLiteral + case FalseLiteral => TrueLiteral + case v => JavaCode.isNullExpression(s"!$v") + } + ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NOT NULL)" @@ -416,8 +421,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "atLeastNNonNulls", - extraArguments = (ctx.JAVA_INT, nonnull) :: Nil, - returnType = ctx.JAVA_INT, + extraArguments = (CodeGenerator.JAVA_INT, nonnull) :: Nil, + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" |do { @@ -436,11 +441,11 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate ev.copy(code = s""" - |${ctx.JAVA_INT} $nonnull = 0; + |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes |} while (false); - |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; - """.stripMargin, isNull = "false") + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; + """.stripMargin, isNull = FalseLiteral) } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 64da9bb9cdec..f974fd81fc78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.objects -import java.lang.reflect.Modifier +import java.lang.reflect.{Method, Modifier} +import scala.collection.JavaConverters._ import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag @@ -27,13 +28,15 @@ import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.Utils /** * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]]. @@ -59,16 +62,16 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param ctx a [[CodegenContext]] * @return (code to prepare arguments, argument string, result of argument null check) */ - def prepareArguments(ctx: CodegenContext): (String, String, String) = { + def prepareArguments(ctx: CodegenContext): (String, String, ExprValue) = { val resultIsNull = if (needNullCheck) { - val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") - resultIsNull + val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") + JavaCode.isNullGlobal(resultIsNull) } else { - "false" + FalseLiteral } val argValues = arguments.map { e => - val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue") + val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") argValue } @@ -103,6 +106,93 @@ trait InvokeLike extends Expression with NonSQLExpression { (argCode, argValues.mkString(", "), resultIsNull) } + + /** + * Evaluate each argument with a given row, invoke a method with a given object and arguments, + * and cast a return value if the return type can be mapped to a Java Boxed type + * + * @param obj the object for the method to be called. 
If null, perform a static method call + * @param method the method object to be called + * @param arguments the arguments used for the method call + * @param input the row used for evaluating arguments + * @param dataType the data type of the return object + * @return the return object of a method call + */ + def invoke( + obj: Any, + method: Method, + arguments: Seq[Expression], + input: InternalRow, + dataType: DataType): Any = { + val args = arguments.map(e => e.eval(input).asInstanceOf[Object]) + if (needNullCheck && args.exists(_ == null)) { + // return null if one of arguments is null + null + } else { + val ret = method.invoke(obj, args: _*) + val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType) + if (boxedClass.isDefined) { + boxedClass.get.cast(ret) + } else { + ret + } + } + } +} + +/** + * Common trait for [[DecodeUsingSerializer]] and [[EncodeUsingSerializer]] + */ +trait SerializerSupport { + /** + * If true, Kryo serialization is used, otherwise the Java one is used + */ + val kryo: Boolean + + /** + * The serializer instance to be used for serialization/deserialization in interpreted execution + */ + lazy val serializerInstance: SerializerInstance = SerializerSupport.newSerializer(kryo) + + /** + * Adds an immutable state to the generated class containing a reference to the serializer. + * @return a string containing the name of the variable referencing the serializer + */ + def addImmutableSerializerIfNeeded(ctx: CodegenContext): String = { + val (serializerInstance, serializerInstanceClass) = { + if (kryo) { + ("kryoSerializer", + classOf[KryoSerializerInstance].getName) + } else { + ("javaSerializer", + classOf[JavaSerializerInstance].getName) + } + } + val newSerializerMethod = s"${classOf[SerializerSupport].getName}$$.MODULE$$.newSerializer" + // Code to initialize the serializer + ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializerInstance, v => + s""" + |$v = ($serializerInstanceClass) $newSerializerMethod($kryo); + """.stripMargin) + serializerInstance + } +} + +object SerializerSupport { + /** + * It creates a new `SerializerInstance` which is either a `KryoSerializerInstance` (if + * `useKryo` is set to `true`) or a `JavaSerializerInstance`.
+ */ + def newSerializer(useKryo: Boolean): SerializerInstance = { + // try conf from env, otherwise create a new one + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + val s = if (useKryo) { + new KryoSerializer(conf) + } else { + new JavaSerializer(conf) + } + s.newInstance() + } } /** @@ -129,15 +219,24 @@ case class StaticInvoke( returnNullable: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") + val cls = if (staticObject.getName == objectName) { + staticObject + } else { + Utils.classForName(objectName) + } override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + @transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*) + + override def eval(input: InternalRow): Any = { + invoke(null, method, arguments, input, dataType) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -146,12 +245,12 @@ case class StaticInvoke( val prepareIsNull = if (nullable) { s"boolean ${ev.isNull} = $resultIsNull;" } else { - ev.isNull = "false" + ev.isNull = FalseLiteral "" } val evaluate = if (returnNullable) { - if (ctx.defaultValue(dataType) == "null") { + if (CodeGenerator.defaultValue(dataType) == "null") { s""" ${ev.value} = $callFunc; ${ev.isNull} = ${ev.value} == null; @@ -159,7 +258,7 @@ case class StaticInvoke( } else { val boxedResult = ctx.freshName("boxedResult") s""" - ${ctx.boxedType(dataType)} $boxedResult = $callFunc; + ${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc; ${ev.isNull} = $boxedResult == null; if (!${ev.isNull}) { ${ev.value} = $boxedResult; @@ -173,7 +272,7 @@ case class StaticInvoke( val code = s""" $argCode $prepareIsNull - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!$resultIsNull) { $evaluate } @@ -208,12 +307,11 @@ case class Invoke( propagateNull: Boolean = true, returnNullable : Boolean = true) extends InvokeLike { + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable override def children: Seq[Expression] = targetObject +: arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - private lazy val encodedFunctionName = TermName(functionName).encodedName.toString @transient lazy val method = targetObject.dataType match { @@ -227,8 +325,23 @@ case class Invoke( case _ => None } + override def eval(input: InternalRow): Any = { + val obj = targetObject.eval(input) + if (obj == null) { + // return null if obj is null + null + } else { + val invokeMethod = if (method.isDefined) { + method.get + } else { + obj.getClass.getDeclaredMethod(functionName, argClasses: _*) + } + invoke(obj, invokeMethod, arguments, input, dataType) + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val obj = targetObject.genCode(ctx) val (argCode, argString, 
resultIsNull) = prepareArguments(ctx) @@ -255,11 +368,11 @@ case class Invoke( // If the function can return null, we do an extra check to make sure our null bit is still // set correctly. val assignResult = if (!returnNullable) { - s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" + s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;" } else { s""" if ($funcResult != null) { - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult; } else { ${ev.isNull} = true; } @@ -275,7 +388,7 @@ case class Invoke( val code = s""" ${obj.code} boolean ${ev.isNull} = true; - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { $argCode ${ev.isNull} = $resultIsNull; @@ -337,11 +450,35 @@ case class NewInstance( childrenResolved && !needOuterPointer } - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + @transient private lazy val constructor: (Seq[AnyRef]) => Any = { + val paramTypes = ScalaReflection.expressionJavaClasses(arguments) + val getConstructor = (paramClazz: Seq[Class[_]]) => { + ScalaReflection.findConstructor(cls, paramClazz).getOrElse { + sys.error(s"Couldn't find a valid constructor on $cls") + } + } + outerPointer.map { p => + val outerObj = p() + val d = outerObj.getClass +: paramTypes + val c = getConstructor(outerObj.getClass +: paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(outerObj +: args: _*) + } + }.getOrElse { + val c = getConstructor(paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(args: _*) + } + } + } + + override def eval(input: InternalRow): Any = { + val argValues = arguments.map(_.eval(input)) + constructor(argValues.map(_.asInstanceOf[AnyRef])) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -358,7 +495,8 @@ case class NewInstance( val code = s""" $argCode ${outer.map(_.code).getOrElse("")} - final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall; + final $javaType ${ev.value} = ${ev.isNull} ? + ${CodeGenerator.defaultValue(dataType)} : $constructorCall; """ ev.copy(code = code) } @@ -381,19 +519,25 @@ case class UnwrapOption( override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = { + val inputObject = child.eval(input) + if (inputObject == null) { + null + } else { + inputObject.asInstanceOf[Option[_]].orNull + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) val code = s""" ${inputObject.code} final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); - $javaType ${ev.value} = ${ev.isNull} ? - ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get(); + $javaType ${ev.value} = ${ev.isNull} ? 
${CodeGenerator.defaultValue(dataType)} : + (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); """ ev.copy(code = code) } @@ -415,8 +559,7 @@ case class WrapOption(child: Expression, optType: DataType) override def inputTypes: Seq[AbstractDataType] = optType :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = Option(child.eval(input)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) @@ -428,7 +571,7 @@ case class WrapOption(child: Expression, optType: DataType) ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } @@ -440,12 +583,33 @@ case class LambdaVariable( value: String, isNull: String, dataType: DataType, - nullable: Boolean = true) extends LeafExpression - with Unevaluable with NonSQLExpression { + nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + + // Interpreted execution of `LambdaVariable` always gets the 0-index element from the input row. + override def eval(input: InternalRow): Any = { + assert(input.numFields == 1, + "The input row of interpreted LambdaVariable should have only 1 field.") + if (nullable && input.isNullAt(0)) { + null + } else { + accessor(input, 0) + } + } override def genCode(ctx: CodegenContext): ExprCode = { + val isNullValue = if (nullable) { + JavaCode.isNullVariable(isNull) + } else { + FalseLiteral + } + ExprCode(value = JavaCode.variable(value, dataType), isNull = isNullValue) } - ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") + + // This won't be called as `genCode` is overridden; we only override it to make + // `LambdaVariable` non-abstract. + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev } /** @@ -538,15 +702,99 @@ case class MapObjects private( override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + // Data with a UserDefinedType is actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use that sqlType.
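+ // For example, a value of an ML VectorUDT column is stored using the UDT's sqlType (a struct), + // so the element conversion below follows that sqlType rather than the UDT class itself.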
+ lazy private val inputDataType = inputData.dataType match { + case u: UserDefinedType[_] => u.sqlType + case _ => inputData.dataType + } + + private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = { + val row = new GenericInternalRow(1) + inputCollection.toIterator.map { element => + row.update(0, element) + lambdaFunction.eval(row) + } + } + + private lazy val convertToSeq: Any => Seq[_] = inputDataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + _.asInstanceOf[Seq[_]] + case ObjectType(cls) if cls.isArray => + _.asInstanceOf[Array[_]].toSeq + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + _.asInstanceOf[java.util.List[_]].asScala + case ObjectType(cls) if cls == classOf[Object] => + (inputCollection) => { + if (inputCollection.getClass.isArray) { + inputCollection.asInstanceOf[Array[_]].toSeq + } else { + inputCollection.asInstanceOf[Seq[_]] + } + } + case ArrayType(et, _) => + _.asInstanceOf[ArrayData].toSeq[Any](et) + } + + private lazy val mapElements: Seq[_] => Any = customCollectionCls match { + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + // Scala sequence + executeFuncOnCollection(_).toSeq + case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + // Scala set + executeFuncOnCollection(_).toSet + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + // Specifying non concrete implementations of `java.util.List` + executeFuncOnCollection(_).toSeq.asJava + } else { + val constructors = cls.getConstructors() + val intParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int] + } + val noParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 0 + } + + val constructor = intParamConstructor.map { intConstructor => + (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object]) + }.getOrElse { + (_: Int) => noParamConstructor.get.newInstance() + } + + // Specifying concrete implementations of `java.util.List` + (inputs) => { + val results = executeFuncOnCollection(inputs) + val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]] + results.foreach(builder.add(_)) + builder + } + } + case None => + // array + x => new GenericArrayData(executeFuncOnCollection(x).toArray) + case Some(cls) => + throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " + + "resulting collection.") + } + + override def eval(input: InternalRow): Any = { + val inputCollection = inputData.eval(input) + + if (inputCollection == null) { + return null + } + mapElements(convertToSeq(inputCollection)) + } override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse( ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val elementJavaType = ctx.javaType(loopVarDataType) + val elementJavaType = CodeGenerator.javaType(loopVarDataType) ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) @@ -554,7 +802,7 @@ case class MapObjects private( val convertedArray = ctx.freshName("convertedArray") val loopIndex = 
ctx.freshName("loopIndex") - val convertedType = ctx.boxedType(lambdaFunction.dataType) + val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType) // Because of the way Java defines nested arrays, we have to handle the syntax specially. // Specifically, we have to insert the [$dataLength] in between the type and any extra nested @@ -586,13 +834,6 @@ case class MapObjects private( case _ => "" } - // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. - // When we want to apply MapObjects on it, we have to use it. - val inputDataType = inputData.dataType match { - case p: PythonUserDefinedType => p.sqlType - case _ => inputData.dataType - } - // `MapObjects` generates a while loop to traverse the elements of the input collection. We // need to take care of Seq and List because they may have O(n) complexity for indexed accessing // like `list.get(1)`. Here we use Iterator to traverse Seq and List. @@ -621,7 +862,7 @@ case class MapObjects private( ( s"${genInputData.value}.numElements()", "", - ctx.getValue(genInputData.value, et, loopIndex) + CodeGenerator.getValue(genInputData.value, et, loopIndex) ) case ObjectType(cls) if cls == classOf[Object] => val it = ctx.freshName("it") @@ -635,7 +876,7 @@ case class MapObjects private( // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" - val genFunctionValue = lambdaFunction.dataType match { + val genFunctionValue: String = lambdaFunction.dataType match { case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) @@ -643,7 +884,8 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ -695,7 +937,7 @@ case class MapObjects private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { $determineCollectionType @@ -792,8 +1034,41 @@ case class CatalystToExternalMap private( override def children: Seq[Expression] = keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val inputMapType = inputData.dataType.asInstanceOf[MapType] + + private lazy val keyConverter = + CatalystTypeConverters.createToScalaConverter(inputMapType.keyType) + private lazy val valueConverter = + CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) + + private lazy val (newMapBuilderMethod, moduleField) = { + val clazz = Utils.classForName(collClass.getCanonicalName + "$") + (clazz.getMethod("newBuilder"), clazz.getField("MODULE$").get(null)) + } + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + 
newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]] + } + + override def eval(input: InternalRow): Any = { + val result = inputData.eval(input).asInstanceOf[MapData] + if (result != null) { + val builder = newMapBuilder() + builder.sizeHint(result.numElements()) + val keyArray = result.keyArray() + val valueArray = result.valueArray() + var i = 0 + while (i < result.numElements()) { + val key = keyConverter(keyArray.get(i, inputMapType.keyType)) + val value = valueConverter(valueArray.get(i, inputMapType.valueType)) + builder += Tuple2(key, value) + i += 1 + } + builder.result() + } else { + null + } + } override def dataType: DataType = ObjectType(collClass) @@ -806,10 +1081,10 @@ case class CatalystToExternalMap private( } val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] - val keyElementJavaType = ctx.javaType(mapType.keyType) + val keyElementJavaType = CodeGenerator.javaType(mapType.keyType) ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false) val genKeyFunction = keyLambdaFunction.genCode(ctx) - val valueElementJavaType = ctx.javaType(mapType.valueType) + val valueElementJavaType = CodeGenerator.javaType(mapType.valueType) ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true, useFreshName = false) val genValueFunction = valueLambdaFunction.genCode(ctx) @@ -825,10 +1100,11 @@ case class CatalystToExternalMap private( val valueArray = ctx.freshName("valueArray") val getKeyArray = s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" - val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) + val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) val getValueArray = s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" - val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex) + val getValueLoopVar = CodeGenerator.getValue( + valueArray, inputDataType(mapType.valueType), loopIndex) // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = @@ -844,7 +1120,7 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, useFreshName = false) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { @@ -873,7 +1149,7 @@ case class CatalystToExternalMap private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { int $dataLength = $getLength; @@ -979,8 +1255,72 @@ case class ExternalMapToCatalyst private( override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = { + val rowBuffer = InternalRow.fromSeq(Array[Any](1)) + def rowWrapper(data: Any): InternalRow = { + rowBuffer.update(0, data) + rowBuffer + } + + 
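+ // The converter is built once per expression: it dispatches on whether the external map is a + // java.util.Map or a scala.collection.Map, and feeds every key and value through the + // key/value converter expressions via the shared single-field rowBuffer.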
child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 + } + (keys, values) + } + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 + } + (keys, values) + } + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (result != null) { + val (keys, values) = mapCatalystConverter(result) + new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) + } else { + null + } + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputMap = child.genCode(ctx) @@ -993,8 +1333,8 @@ case class ExternalMapToCatalyst private( val entry = ctx.freshName("entry") val entries = ctx.freshName("entries") - val keyElementJavaType = ctx.javaType(keyType) - val valueElementJavaType = ctx.javaType(valueType) + val keyElementJavaType = CodeGenerator.javaType(keyType) + val valueElementJavaType = CodeGenerator.javaType(valueType) ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false) ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false) @@ -1009,8 +1349,8 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry.getKey(); - $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue(); """ defineEntries -> defineKeyValue @@ -1024,22 +1364,24 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry._1(); - $value = (${ctx.boxedType(valueType)}) $entry._2(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry._1(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry._2(); """ defineEntries -> defineKeyValue } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, 
valueIsNull, forceInline = true, useFreshName = false) s"$valueIsNull = $value == null;" } else { "" @@ -1047,12 +1389,12 @@ case class ExternalMapToCatalyst private( val arrayCls = classOf[GenericArrayData].getName val mapCls = classOf[ArrayBasedMapData].getName - val convertedKeyType = ctx.boxedType(keyConverter.dataType) - val convertedValueType = ctx.boxedType(valueConverter.dataType) + val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) + val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) val code = s""" ${inputMap.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); final Object[] $convertedKeys = new Object[$length]; @@ -1101,8 +1443,10 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def nullable: Boolean = false - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = { + val values = children.map(_.eval(input)).toArray + new GenericRowWithSchema(values, schema) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName @@ -1132,7 +1476,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); """.stripMargin - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } @@ -1142,44 +1486,23 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) * @param kryo if true, use Kryo. Otherwise, use Java. */ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) - extends UnaryExpression with NonSQLExpression { + extends UnaryExpression with NonSQLExpression with SerializerSupport { - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def nullSafeEval(input: Any): Any = { + serializerInstance.serialize(input).array() + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Code to initialize the serializer. - val (serializer, serializerClass, serializerInstanceClass) = { - if (kryo) { - ("kryoSerializer", - classOf[KryoSerializer].getName, - classOf[KryoSerializerInstance].getName) - } else { - ("javaSerializer", - classOf[JavaSerializer].getName, - classOf[JavaSerializerInstance].getName) - } - } - // try conf from env, otherwise create a new one - val env = s"${classOf[SparkEnv].getName}.get()" - val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => - s""" - |if ($env == null) { - | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - |} else { - | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - |} - """.stripMargin) - + val serializer = addImmutableSerializerIfNeeded(ctx) // Code to serialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? 
${ctx.defaultValue(javaType)} : $serialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1194,42 +1517,25 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) * @param kryo if true, use Kryo. Otherwise, use Java. */ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) - extends UnaryExpression with NonSQLExpression { + extends UnaryExpression with NonSQLExpression with SerializerSupport { - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Code to initialize the serializer. - val (serializer, serializerClass, serializerInstanceClass) = { - if (kryo) { - ("kryoSerializer", - classOf[KryoSerializer].getName, - classOf[KryoSerializerInstance].getName) - } else { - ("javaSerializer", - classOf[JavaSerializer].getName, - classOf[JavaSerializerInstance].getName) - } - } - // try conf from env, otherwise create a new one - val env = s"${classOf[SparkEnv].getName}.get()" - val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => - s""" - |if ($env == null) { - | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - |} else { - | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - |} - """.stripMargin) + override def nullSafeEval(input: Any): Any = { + val inputBytes = java.nio.ByteBuffer.wrap(input.asInstanceOf[Array[Byte]]) + serializerInstance.deserialize(inputBytes) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val serializer = addImmutableSerializerIfNeeded(ctx) // Code to deserialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1247,21 +1553,60 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val resolvedSetters = { + assert(beanInstance.dataType.isInstanceOf[ObjectType]) + + val ObjectType(beanClass) = beanInstance.dataType + setters.map { + case (name, expr) => + // Looking for known type mapping. + // But also looking for general `Object`-type parameter for generic methods. 
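+ // For example, a generic bean may only declare setValue(Object v); such a setter should still + // be matched even when the field expression's Java class is more specific, e.g. java.lang.String.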
+ val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) + val methods = paramTypes.flatMap { fieldClass => + try { + Some(beanClass.getDeclaredMethod(name, fieldClass)) + } catch { + case e: NoSuchMethodException => None + } + } + if (methods.isEmpty) { + throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + + "in any enclosing class nor any supertype") + } + methods.head -> expr + } + } + + override def eval(input: InternalRow): Any = { + val instance = beanInstance.eval(input) + if (instance != null) { + val bean = instance.asInstanceOf[Object] + resolvedSetters.foreach { + case (setter, expr) => + val paramVal = expr.eval(input) + // We don't call setter if input value is null. + if (paramVal != null) { + setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) + } + } + } + instance + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) val javaBeanInstance = ctx.freshName("javaBean") - val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) + val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType) val initialize = setters.map { case (setterMethod, fieldValue) => val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} - |$javaBeanInstance.$setterMethod(${fieldGen.value}); + |if (!${fieldGen.isNull}) { + | $javaBeanInstance.$setterMethod(${fieldGen.value}); + |} """.stripMargin } val initializeCode = ctx.splitExpressionsWithCurrentInputs( @@ -1326,7 +1671,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) throw new NullPointerException($errMsgField); } """ - ev.copy(code = code, isNull = "false", value = childGen.value) + ev.copy(code = code, isNull = FalseLiteral, value = childGen.value) } } @@ -1346,11 +1691,19 @@ case class GetExternalRowField( override def dataType: DataType = ObjectType(classOf[Object]) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null." + override def eval(input: InternalRow): Any = { + val inputRow = child.eval(input).asInstanceOf[Row] + if (inputRow == null) { + throw new RuntimeException("The input external row cannot be null.") + } + if (inputRow.isNullAt(index)) { + throw new RuntimeException(errMsg) + } + inputRow.get(index) + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the field is null. 
@@ -1369,7 +1722,7 @@ case class GetExternalRowField( final Object ${ev.value} = ${row.value}.get($index); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } @@ -1384,13 +1737,36 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def nullable: Boolean = child.nullable - override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected) - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + private lazy val checkType: (Any) => Boolean = expected match { + case _: DecimalType => + (value: Any) => { + value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] || + value.isInstanceOf[Decimal] + } + case _: ArrayType => + (value: Any) => { + value.getClass.isArray || value.isInstanceOf[Seq[_]] + } + case _ => + val dataTypeClazz = ScalaReflection.javaBoxedType(dataType) + (value: Any) => { + dataTypeClazz.isInstance(value) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (checkType(result)) { + result + } else { + throw new RuntimeException(s"${result.getClass.getName}$errMsg") + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the type doesn't match. @@ -1403,17 +1779,17 @@ case class ValidateExternalType(child: Expression, expected: DataType) Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") case _: ArrayType => - s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" + s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}" case _ => - s"$obj instanceof ${ctx.boxedType(dataType)}" + s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } val code = s""" ${input.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { if ($typeCheck) { - ${ev.value} = (${ctx.boxedType(dataType)}) $obj; + ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj; } else { throw new RuntimeException($obj.getClass().getName() + $errMsgField); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1a48995358af..8a06daa37132 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.catalyst +import java.util.Locale + import com.google.common.collect.Maps +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -138,6 +142,88 @@ package object expressions { def indexOf(exprId: ExprId): Int = { Option(exprIdToOrdinal.get(exprId)).getOrElse(-1) } + + private def unique[T](m: 
Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = { + m.mapValues(_.distinct).map(identity) + } + + /** Map to use for direct case insensitive attribute lookups. */ + @transient private lazy val direct: Map[String, Seq[Attribute]] = { + unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) + } + + /** Map to use for qualified case insensitive attribute lookups. */ + @transient private val qualified: Map[(String, String), Seq[Attribute]] = { + val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a => + (a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + } + unique(grouped) + } + + /** Perform attribute resolution given a name and a resolver. */ + def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + // Collect matching attributes given a name and a lookup. + def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = { + candidates.toSeq.flatMap(_.collect { + case a if resolver(a.name, name) => a.withName(name) + }) + } + + // Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name, + // alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of + // matched attributes and a list of parts that are to be resolved. + // + // For example, consider a case where "a" is the table name, "b" is the column name, + // and "c" is the struct field name, i.e. "a.b.c". In this case, the matched attribute will be + // "a.b", and the second element will be List("c"). + val matches = nameParts match { + case qualifier +: name +: nestedFields => + val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified.get(key)).filter { a => + resolver(qualifier, a.qualifier.get) + } + (attributes, nestedFields) + case all => + (Nil, all) + } + + // If none of the attributes match the `table.column` pattern, we try to resolve it as a column. + val (candidates, nestedFields) = matches match { + case (Seq(), _) => + val name = nameParts.head + val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) + (attributes, nameParts.tail) + case _ => matches + } + + def name = UnresolvedAttribute(nameParts).name + candidates match { + case Seq(a) if nestedFields.nonEmpty => + // One match, but we also need to extract the requested nested field. + // The foldLeft adds an ExtractValue for every remaining part of the identifier, + // and aliases it with the last part of the name. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression) { (e, name) => + ExtractValue(e, Literal(name), resolver) + } + Some(Alias(fieldExprs, nestedFields.last)()) + + case Seq(a) => + // One match, no nested fields, use it. + Some(a) + + case Seq() => + // No matches. + None + + case ambiguousReferences => + // More than one match.
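+ // For example, two joined relations may both expose a column named "id" while the query + // refers to it without a table qualifier.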
+ val referenceNames = ambiguousReferences.map(_.qualifiedName).mkString(", ") + throw new AnalysisException(s"Reference '$name' is ambiguous, could be: $referenceNames.") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b469f5cb7586..f8c6dc4e6adc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -36,6 +36,14 @@ object InterpretedPredicate { case class InterpretedPredicate(expression: Expression) extends BasePredicate { override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] + + override def initialize(partitionIndex: Int): Unit = { + super.initialize(partitionIndex) + expression.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + } + } } /** @@ -157,7 +165,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") override def checkInputDataTypes(): TypeCheckResult = { - val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType)) + val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, + ignoreNullability = true)) if (mismatchOpt.isDefined) { list match { case ListQuery(_, _, _, childOutputs) :: Nil => @@ -234,7 +243,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaDataType = ctx.javaType(value.dataType) + val javaDataType = CodeGenerator.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: @@ -262,8 +271,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil, - returnType = ctx.JAVA_BYTE, + extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = body => s""" |do { @@ -347,8 +356,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ev.copy(code = s""" |${childGen.code} - |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; - |${ctx.JAVA_BOOLEAN} ${ev.value} = false; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; |if (!${ev.isNull}) { | ${ev.value} = $setTerm.contains(${childGen.value}); | $setIsNull @@ -404,7 +413,7 @@ case class 
And(left: Expression, right: Expression) extends BinaryOperator with if (${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -460,7 +469,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.isNull = "false" + ev.isNull = FalseLiteral ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = true; @@ -468,7 +477,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P if (!${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -504,7 +513,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (ctx.isPrimitiveType(left.dataType) + if (CodeGenerator.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType && left.dataType != DoubleType) { @@ -614,7 +623,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || - (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = "false") + (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bc936fcbfc3..70186053617f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -31,7 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom * * Since this expression is stateful, it cannot be a case object. */ -abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic { +abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful { /** * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize and initialize it. 
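The change above moves the random-number generators from Nondeterministic to Stateful, and the two hunks below add freshCopy() overrides for Rand and Randn. A minimal sketch of why a stateful expression wants a per-consumer copy (this is not Spark's actual trait hierarchy; the names SimpleStateful/SimpleRand are illustrative only):

    import org.apache.spark.util.random.XORShiftRandom

    // Illustrative only: a stateful "random" value whose mutable RNG state must
    // not be shared between two independent consumers of the same expression.
    trait SimpleStateful {
      def freshCopy(): SimpleStateful
    }

    case class SimpleRand(seed: Long) extends SimpleStateful {
      @transient private var rng: XORShiftRandom = _

      // Mirrors the partition-seeded initialization done in the generated code.
      def initialize(partitionIndex: Int): Unit = {
        rng = new XORShiftRandom(seed + partitionIndex)
      }

      def next(): Double = rng.nextDouble()

      // Each consumer gets its own copy, and therefore its own RNG stream;
      // reusing one instance would entangle the streams of both consumers.
      override def freshCopy(): SimpleRand = SimpleRand(seed)
    }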
@@ -82,8 +82,11 @@ case class Rand(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", + isNull = FalseLiteral) } + + override def freshCopy(): Rand = Rand(child) } object Rand { @@ -116,8 +119,11 @@ case class Randn(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", + isNull = FalseLiteral) } + + override def freshCopy(): Randn = Randn(child) } object Randn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index f3e8f6de5897..ad0c0791d895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -126,7 +126,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } @@ -134,7 +134,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { @@ -201,7 +201,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); } @@ -209,7 +209,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0da55a4a961..ea005a26a4c8 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -36,87 +36,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// -/** - * An expression 
that concatenates multiple inputs into a single output. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * If any input is null, concat returns null. - */ -@ExpressionDescription( - usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.", - examples = """ - Examples: - > SELECT _FUNC_('Spark', 'SQL'); - SparkSQL - """) -case class Concat(children: Seq[Expression]) extends Expression { - - private lazy val isBinaryMode: Boolean = dataType == BinaryType - - override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { - TypeCheckResult.TypeCheckSuccess - } else { - val childTypes = children.map(_.dataType) - if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have StringType or BinaryType, but it's " + - childTypes.map(_.simpleString).mkString("[", ", ", "]")) - } - TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") - } - } - - override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) - - override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - - override def eval(input: InternalRow): Any = { - if (isBinaryMode) { - val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) - ByteArray.concat(inputs: _*) - } else { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val evals = children.map(_.genCode(ctx)) - val args = ctx.freshName("args") - - val inputs = evals.zipWithIndex.map { case (eval, index) => - s""" - ${eval.code} - if (!${eval.isNull}) { - $args[$index] = ${eval.value}; - } - """ - } - - val (concatenator, initCode) = if (isBinaryMode) { - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") - } else { - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") - } - val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = inputs, - funcName = "valueConcat", - extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil) - ev.copy(s""" - $initCode - $codes - ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args); - boolean ${ev.isNull} = ${ev.value} == null; - """) - } - - override def toString: String = s"concat(${children.mkString(", ")})" - - override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" -} - - /** * An expression that concatenates multiple input strings or array of strings into a single string, * using a given separator (the first child). @@ -196,7 +115,7 @@ case class ConcatWs(children: Seq[Expression]) } else { val array = ctx.freshName("array") val varargNum = ctx.freshName("varargNum") - val idxInVararg = ctx.freshName("idxInVararg") + val idxVararg = ctx.freshName("idxInVararg") val evals = children.map(_.genCode(ctx)) val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) => @@ -206,7 +125,7 @@ case class ConcatWs(children: Seq[Expression]) if (eval.isNull == "true") { "" } else { - s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" + s"$array[$idxVararg ++] = ${eval.isNull} ? 
(UTF8String) null : ${eval.value};" }) case _: ArrayType => val size = ctx.freshName("n") @@ -222,7 +141,7 @@ case class ConcatWs(children: Seq[Expression]) if (!${eval.isNull}) { final int $size = ${eval.value}.numElements(); for (int j = 0; j < $size; j ++) { - $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; + $array[$idxVararg ++] = ${CodeGenerator.getValue(eval.value, StringType, "j")}; } } """) @@ -247,20 +166,20 @@ case class ConcatWs(children: Seq[Expression]) val varargBuilds = ctx.splitExpressionsWithCurrentInputs( expressions = varargBuild, funcName = "varargBuildsConcatWs", - extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + extraArguments = ("UTF8String []", array) :: ("int", idxVararg) :: Nil, returnType = "int", makeSplitFunction = body => s""" |$body - |return $idxInVararg; + |return $idxVararg; """.stripMargin, - foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n")) + foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( s""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; - int $idxInVararg = 0; + int $idxVararg = 0; $varargCounts UTF8String[] $array = new UTF8String[$varargNum]; $varargBuilds @@ -271,33 +190,45 @@ case class ConcatWs(children: Seq[Expression]) } } +/** + * An expression that returns the `n`-th input in given inputs. + * If all inputs are binary, `elt` returns an output as binary. Otherwise, it returns as string. + * If any input is null, `elt` returns null. + */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., returns `str2` when `n` is 2.", + usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.", examples = """ Examples: > SELECT _FUNC_(1, 'scala', 'java'); scala """) // scalastyle:on line.size.limit -case class Elt(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes { +case class Elt(children: Seq[Expression]) extends Expression { private lazy val indexExpr = children.head - private lazy val stringExprs = children.tail.toArray + private lazy val inputExprs = children.tail.toArray /** This expression is always nullable because it returns null if index is out of range. 
*/ override def nullable: Boolean = true - override def dataType: DataType = StringType - - override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType) + override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType) override def checkInputDataTypes(): TypeCheckResult = { if (children.size < 2) { TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments") } else { - super[ImplicitCastInputTypes].checkInputDataTypes() + val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) + if (indexType != IntegerType) { + return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + + s"have IntegerType, but it's $indexType") + } + if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have StringType or BinaryType, but it's " + + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") } } @@ -307,27 +238,27 @@ case class Elt(children: Seq[Expression]) null } else { val index = indexObj.asInstanceOf[Int] - if (index <= 0 || index > stringExprs.length) { + if (index <= 0 || index > inputExprs.length) { null } else { - stringExprs(index - 1).eval(input) + inputExprs(index - 1).eval(input) } } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = indexExpr.genCode(ctx) - val strings = stringExprs.map(_.genCode(ctx)) + val inputs = inputExprs.map(_.genCode(ctx)) val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal") + val inputVal = ctx.addMutableState(CodeGenerator.javaType(dataType), "inputVal") - val assignStringValue = strings.zipWithIndex.map { case (eval, index) => + val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" |if ($indexVal == ${index + 1}) { | ${eval.code} - | $stringVal = ${eval.isNull} ? null : ${eval.value}; + | $inputVal = ${eval.isNull} ? 
null : ${eval.value}; | $indexMatched = true; | continue; |} @@ -335,13 +266,13 @@ case class Elt(children: Seq[Expression]) } val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = assignStringValue, + expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, - returnType = ctx.JAVA_BOOLEAN, + returnType = CodeGenerator.JAVA_BOOLEAN, makeSplitFunction = body => s""" - |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; |do { | $body |} while (false); @@ -360,12 +291,12 @@ case class Elt(children: Seq[Expression]) s""" |${index.code} |final int $indexVal = ${index.value}; - |${ctx.JAVA_BOOLEAN} $indexMatched = false; - |$stringVal = null; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; + |$inputVal = null; |do { | $codes |} while (false); - |final UTF8String ${ev.value} = $stringVal; + |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } @@ -1398,10 +1329,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val numArgLists = argListGen.length val argListCode = argListGen.zipWithIndex.map { case(v, index) => val value = - if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { + if (CodeGenerator.boxedType(v._1) != CodeGenerator.javaType(v._1)) { // Java primitives get boxed in order to allow null values. - s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + - s"new ${ctx.boxedType(v._1)}(${v._2.value})" + s"(${v._2.isNull}) ? (${CodeGenerator.boxedType(v._1)}) null : " + + s"new ${CodeGenerator.boxedType(v._1)}(${v._2.value})" } else { s"(${v._2.isNull}) ? null : ${v._2.value}" } @@ -1422,7 +1353,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC ev.copy(code = s""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $stringBuffer $sb = new $stringBuffer(); $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); @@ -1492,26 +1423,6 @@ case class StringRepeat(str: Expression, times: Expression) } } -/** - * Returns the reversed given string. - */ -@ExpressionDescription( - usage = "_FUNC_(str) - Returns the reversed given string.", - examples = """ - Examples: - > SELECT _FUNC_('Spark SQL'); - LQS krapS - """) -case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { - override def convert(v: UTF8String): UTF8String = v.reverse() - - override def prettyName: String = "reverse" - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).reverse()") - } -} - /** * Returns a string consisting of n spaces. */ @@ -1641,19 +1552,19 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run * A function that returns the char length of the given string expression or * number of bytes of the given binary expression. */ -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the character length of `expr` or number of bytes in binary data.", + usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of " + + "binary data. The length of string data includes the trailing spaces. 
The length of binary " + + "data includes binary zeros.", examples = """ Examples: - > SELECT _FUNC_('Spark SQL'); - 9 - > SELECT CHAR_LENGTH('Spark SQL'); - 9 - > SELECT CHARACTER_LENGTH('Spark SQL'); - 9 + > SELECT _FUNC_('Spark SQL '); + 10 + > SELECT CHAR_LENGTH('Spark SQL '); + 10 + > SELECT CHARACTER_LENGTH('Spark SQL '); + 10 """) -// scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1675,7 +1586,7 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn * A function that returns the bit length of the given string or binary expression. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the bit length of `expr` or number of bits in binary data.", + usage = "_FUNC_(expr) - Returns the bit length of string data or number of bits of binary data.", examples = """ Examples: > SELECT _FUNC_('Spark SQL'); @@ -1696,13 +1607,16 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length * 8") } } + + override def prettyName: String = "bit_length" } /** * A function that returns the byte length of the given string or binary expression. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the byte length of `expr` or number of bytes in binary data.", + usage = "_FUNC_(expr) - Returns the byte length of string data or number of bytes of binary " + + "data.", examples = """ Examples: > SELECT _FUNC_('Spark SQL'); @@ -1723,6 +1637,8 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") } } + + override def prettyName: String = "octet_length" } /** @@ -2093,7 +2009,8 @@ case class FormatNumber(x: Expression, d: Expression) val usLocale = "US" val i = ctx.freshName("i") val dFormat = ctx.freshName("dFormat") - val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index dd13d9a3bba5..9fe2fb2b95e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -265,27 +265,6 @@ case class SpecifiedWindowFrame( } } -object SpecifiedWindowFrame { - /** - * @param hasOrderSpecification If the window spec has order by expressions. - * @param acceptWindowFrame If the window function accepts user-specified frame. - * @return the default window frame. - */ - def defaultWindowFrame( - hasOrderSpecification: Boolean, - acceptWindowFrame: Boolean): SpecifiedWindowFrame = { - if (hasOrderSpecification && acceptWindowFrame) { - // If order spec is defined and the window function supports user specified window frames, - // the default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW. 
- SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - } else { - // Otherwise, the default frame is - // ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. - SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) - } - } -} - case class UnresolvedWindowExpression( child: Expression, windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { @@ -363,7 +342,10 @@ abstract class OffsetWindowFunction override lazy val frame: WindowFrame = { val boundary = direction match { case Ascending => offset - case Descending => UnaryMinus(offset) + case Descending => UnaryMinus(offset) match { + case e: Expression if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) + case o => o + } } SpecifiedWindowFrame(RowFrame, boundary, boundary) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index d0185562c9cf..aacf1a44e2ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -160,7 +160,7 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract { """) // scalastyle:on line.size.limit case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract { - override def prettyName: String = "xpath_float" + override def prettyName: String = "xpath_double" override def dataType: DataType = DoubleType override def nullSafeEval(xml: Any, path: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index 025a388aacaa..3e8e6db1dbd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -18,10 +18,14 @@ package org.apache.spark.sql.catalyst.json import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} +import java.nio.channels.Channels +import java.nio.charset.Charset import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.hadoop.io.Text +import sun.nio.cs.StreamDecoder +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.unsafe.types.UTF8String private[sql] object CreateJacksonParser extends Serializable { @@ -43,7 +47,48 @@ private[sql] object CreateJacksonParser extends Serializable { jsonFactory.createParser(record.getBytes, 0, record.getLength) } - def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(record) + // Jackson parsers can be ranked according to their performance: + // 1. Array based with actual encoding UTF-8 in the array. This is the fastest parser + // but it doesn't allow to set encoding explicitly. Actual encoding is detected automatically + // by checking leading bytes of the array. + // 2. InputStream based with actual encoding UTF-8 in the stream. Encoding is detected + // automatically by analyzing first bytes of the input stream. + // 3. Reader based parser. This is the slowest parser used here but it allows to create + // a reader with specific encoding. + // The method creates a reader for an array with given encoding and sets size of internal + // decoding buffer according to size of input array. 
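The ranking in the comment above comes down to this: the byte-array and InputStream parsers let Jackson detect the encoding from the leading bytes, while only a Reader-based parser can force a specific charset. A small standalone sketch of the Reader-based approach; the helper name is made up and a plain InputStreamReader stands in for the internal StreamDecoder used here:

    import java.io.{ByteArrayInputStream, InputStreamReader}
    import com.fasterxml.jackson.core.{JsonFactory, JsonToken}

    // Parse a JSON record whose bytes use an explicitly known, non-UTF-8 charset.
    def parseWithCharset(bytes: Array[Byte], charsetName: String): Unit = {
      val factory = new JsonFactory()
      // The Reader fixes the charset up front instead of letting Jackson guess
      // it from the first bytes of the input.
      val reader = new InputStreamReader(new ByteArrayInputStream(bytes), charsetName)
      val parser = factory.createParser(reader)
      var token: JsonToken = parser.nextToken()
      while (token != null) {
        // ... handle the token ...
        token = parser.nextToken()
      }
      parser.close()
    }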
+ private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = { + val bais = new ByteArrayInputStream(in, 0, length) + val byteChannel = Channels.newChannel(bais) + val decodingBufferSize = Math.min(length, 8192) + val decoder = Charset.forName(enc).newDecoder() + + StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize) + } + + def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = { + val sd = getStreamDecoder(enc, record.getBytes, record.getLength) + jsonFactory.createParser(sd) + } + + def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = { + jsonFactory.createParser(is) + } + + def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = { + jsonFactory.createParser(new InputStreamReader(is, enc)) + } + + def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = { + val ba = row.getBinary(0) + + jsonFactory.createParser(ba, 0, ba.length) + } + + def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = { + val binary = row.getBinary(0) + val sd = getStreamDecoder(enc, binary, binary.length) + + jsonFactory.createParser(sd) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 652412b34478..2579374e3f4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.json +import java.nio.charset.{Charset, StandardCharsets} import java.util.{Locale, TimeZone} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} @@ -85,6 +86,46 @@ private[sql] class JSONOptions( val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + /** + * A string between two consecutive JSON records. + */ + val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + require(sep.nonEmpty, "'lineSep' cannot be an empty string.") + sep + } + + /** + * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. + * If the encoding is not specified (None), it will be detected automatically + * when the multiLine option is set to `true`. + */ + val encoding: Option[String] = parameters.get("encoding") + .orElse(parameters.get("charset")).map { enc => + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. + val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) + val isBlacklisted = blacklist.contains(Charset.forName(enc)) + require(multiLine || !isBlacklisted, + s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled. 
+ |Blacklist: ${blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty + + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.getOrElse("UTF-8")) + } + val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") + /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index eb06e4f304f0..9c413de752a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.Writer +import java.nio.charset.StandardCharsets import com.fasterxml.jackson.core._ @@ -74,6 +75,8 @@ private[sql] class JacksonGenerator( private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + private val lineSeparator: String = options.lineSeparatorInWrite + private def makeWriter(dataType: DataType): ValueWriter = dataType match { case NullType => (row: SpecializedGetters, ordinal: Int) => @@ -251,5 +254,8 @@ private[sql] class JacksonGenerator( mapType = dataType.asInstanceOf[MapType])) } - def writeLineEnding(): Unit = gen.writeRaw('\n') + def writeLineEnding(): Unit = { + // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + gen.writeRaw(lineSeparator) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index bd144c9575c7..a5a4a13eb608 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayOutputStream, CharConversionException} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -357,7 +357,18 @@ class JacksonParser( } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => + // JSON parser currently doesn't support partial results for corrupted records. + // For such records, all fields other than the field configured by + // `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => recordLiteral(record), () => None, e) + case e: CharConversionException if options.encoding.isEmpty => + val msg = + """JSON parser cannot handle a character in its input. + |Specifying encoding as an input option explicitly might help to resolve the issue. 
+ |""".stripMargin + e.getMessage + val wrappedCharException = new CharConversionException(msg) + wrappedCharException.initCause(e) + throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index be0009ec8c76..db7d6d3254bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -18,39 +18,39 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule /** -* push down operations into [[CreateNamedStructLike]]. -*/ -object SimplifyCreateStructOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { - // push down field extraction + * Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions. + */ +object SimplifyExtractValueOps extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // One place where this optimization is invalid is an aggregation where the select + // list expression is a function of a grouping expression: + // + // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b) + // + // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this + // optimization for Aggregates (although this misses some cases where the optimization + // can be made). + case a: Aggregate => a + case p => p.transformExpressionsUp { + // Remove redundant field extraction. case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) => createNamedStructLike.valExprs(ordinal) - } - } -} -/** -* push down operations into [[CreateArray]]. -*/ -object SimplifyCreateArrayOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { - // push down field selection (array of structs) - case GetArrayStructFields(CreateArray(elems), field, ordinal, numFields, containsNull) => - // instead f selecting the field on the entire array, - // select it from each member of the array. - // pushing down the operation this way open other optimizations opportunities - // (i.e. struct(...,x,...).x) + // Remove redundant array indexing. + case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) => + // Instead of selecting the field on the entire array, select it from each member + // of the array. Pushing down the operation this way may open other optimizations + // opportunities (i.e. struct(...,x,...).x) CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name)))) - // push down item selection. + + // Remove redundant map lookup. case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) => - // instead of creating the array and then selecting one row, - // remove array creation altgether. + // Instead of creating the array and then selecting one row, remove array creation + // altogether. 
if (idx >= 0 && idx < elems.size) { // valid index elems(idx) @@ -58,18 +58,7 @@ object SimplifyCreateArrayOps extends Rule[LogicalPlan] { // out of bounds, mimic the runtime behavior and return null Literal(null, ga.dataType) } - } - } -} - -/** -* push down operations into [[CreateMap]]. -*/ -object SimplifyCreateMapOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { case GetMapValue(CreateMap(elems), key) => CaseKeyWhen(key, elems) } } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0d4b02c6e7d8..bfa61116a665 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,9 +85,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) EliminateSerialization, RemoveRedundantAliases, RemoveRedundantProject, - SimplifyCreateStructOps, - SimplifyCreateArrayOps, - SimplifyCreateMapOps, + SimplifyExtractValueOps, CombineConcats) ++ extendedOperatorOptimizationRules @@ -140,6 +138,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) operatorOptimizationBatch) :+ Batch("Join Reorder", Once, CostBasedJoinReorder) :+ + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, @@ -155,7 +155,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) + RemoveRedundantProject) :+ + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) } /** @@ -352,7 +354,6 @@ object LimitPushDown extends Rule[LogicalPlan] { // on both sides if it is applied multiple times. Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. - // - If neither side is limited, limit the side that is estimated to be bigger. case LocalLimit(exp, join @ Join(left, right, joinType, _)) => val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) @@ -636,10 +637,11 @@ object CollapseWindow extends Rule[LogicalPlan] { * constraints. These filters are currently inserted to the existing conditions in the Filter * operators and on either side of Join operators. * - * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and - * LeftSemi joins. + * Note: While this optimization is applicable to a lot of types of join, it primarily benefits + * Inner and LeftSemi joins. 
*/ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { +object InferFiltersFromConstraints extends Rule[LogicalPlan] + with PredicateHelper with ConstraintHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (SQLConf.get.constraintPropagationEnabled) { @@ -660,21 +662,51 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe } case join @ Join(left, right, joinType, conditionOpt) => - // Only consider constraints that can be pushed down completely to either the left or the - // right child - val constraints = join.constraints.filter { c => - c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) - } - // Remove those constraints that are already enforced by either the left or the right child - val additionalConstraints = constraints -- (left.constraints ++ right.constraints) - val newConditionOpt = conditionOpt match { - case Some(condition) => - val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) - if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else None - case None => - additionalConstraints.reduceOption(And) + joinType match { + // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an + // inner join, it just drops the right side in the final output. + case _: InnerLike | LeftSemi => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + val newRight = inferNewFilter(right, allConstraints) + join.copy(left = newLeft, right = newRight) + + // For right outer join, we can only infer additional filters for left side. + case RightOuter => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + join.copy(left = newLeft) + + // For left join, we can only infer additional filters for right side. + case LeftOuter | LeftAnti => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newRight = inferNewFilter(right, allConstraints) + join.copy(right = newRight) + + case _ => join } - if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join + } + + private def getAllConstraints( + left: LogicalPlan, + right: LogicalPlan, + conditionOpt: Option[Expression]): Set[Expression] = { + val baseConstraints = left.constraints.union(right.constraints) + .union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet) + baseConstraints.union(inferAdditionalConstraints(baseConstraints)) + } + + private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = { + val newPredicates = constraints + .union(constructIsNotNullConstraints(constraints, plan.output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic + } -- plan.constraints + if (newPredicates.isEmpty) { + plan + } else { + Filter(newPredicates.reduce(And), plan) + } } } @@ -734,6 +766,33 @@ object EliminateSorts extends Rule[LogicalPlan] { } } +/** + * Removes redundant Sort operation. 
This can happen: + * 1) if the child is already sorted + * 2) if there is another Sort operator separated by 0...n Project/Filter operators + */ +object RemoveRedundantSorts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => + child + case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) + } + + def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match { + case Sort(_, _, child) => recursiveRemoveSort(child) + case other if canEliminateSort(other) => + other.withNewChildren(other.children.map(recursiveRemoveSort)) + case _ => plan + } + + def canEliminateSort(plan: LogicalPlan): Boolean = plan match { + case p: Project => p.projectList.forall(_.deterministic) + case f: Filter => f.condition.deterministic + case _: ResolvedHint => true + case _ => false + } +} + /** * Removes filters that can be evaluated trivially. This can be done through the following ways: * 1) by eliding the filter for cases where it will always evaluate to `true`. @@ -795,7 +854,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) case filter @ Filter(condition, aggregate: Aggregate) - if aggregate.aggregateExpressions.forall(_.deterministic) => + if aggregate.aggregateExpressions.forall(_.deterministic) + && aggregate.groupingExpressions.nonEmpty => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { @@ -1107,23 +1167,29 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { */ def isCartesianProduct(join: Join): Boolean = { val conditions = join.condition.map(splitConjunctivePredicates).getOrElse(Nil) - !conditions.map(_.references).exists(refs => refs.exists(join.left.outputSet.contains) - && refs.exists(join.right.outputSet.contains)) + + conditions match { + case Seq(Literal.FalseLiteral) | Seq(Literal(null, BooleanType)) => false + case _ => !conditions.map(_.references).exists(refs => + refs.exists(join.left.outputSet.contains) && refs.exists(join.right.outputSet.contains)) + } } def apply(plan: LogicalPlan): LogicalPlan = if (SQLConf.get.crossJoinEnabled) { plan } else plan transform { - case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, condition) + case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) if isCartesianProduct(j) => throw new AnalysisException( - s"""Detected cartesian product for ${j.joinType.sql} join between logical plans + s"""Detected implicit cartesian product for ${j.joinType.sql} join between logical plans |${left.treeString(false).trim} |and |${right.treeString(false).trim} |Join condition is missing or trivial. 
- |Use the CROSS JOIN syntax to allow cartesian products between these relations.""" + |Either: use the CROSS JOIN syntax to allow cartesian products between these + |relations, or: enable implicit cartesian products by setting the configuration + |variable spark.sql.crossJoin.enabled=true""" .stripMargin) } } @@ -1221,7 +1287,13 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId) } } - Aggregate(keys, aggCols, child) + // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping + // aggregations by checking the number of grouping keys. The key difference here is that a + // global aggregation always returns at least one row even if there are no input rows. Here + // we append a literal when the grouping key list is empty so that the result aggregate + // operator is properly treated as a grouping aggregation. + val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys + Aggregate(nonemptyKeys, aggCols, child) } } @@ -1292,8 +1364,27 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { */ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(grouping, _, _) => + case a @ Aggregate(grouping, _, _) if grouping.size > 1 => val newGrouping = ExpressionSet(grouping).toSeq - a.copy(groupingExpressions = newGrouping) + if (newGrouping.size == grouping.size) { + a + } else { + a.copy(groupingExpressions = newGrouping) + } + } +} + +/** + * Updates nullability in [[AttributeReference]]s if nullability is different between + * non-leaf plan's expressions and the children output. + */ +object UpdateNullabilityInAttributeReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case p if !p.isInstanceOf[LeafNode] => + val nullabilityMap = AttributeMap(p.children.flatMap(_.output).map { x => x -> x.nullable }) + p transformExpressions { + case ar: AttributeReference if nullabilityMap.contains(ar) => + ar.withNullability(nullabilityMap(ar)) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index a6e5aa6daca6..c3fdb924243d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf /** * Collapse plans consisting empty local relations generated by [[PruneFilters]]. @@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules._ * - Aggregate with all empty children and at least one grouping expression. * - Generate(Explode) with all empty children. Others like Hive UDTF may return results. 
*/ -object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { +object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper with CastSupport { private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = plan match { case p: LocalRelation => p.data.isEmpty case _ => false @@ -43,7 +45,9 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { // Construct a project list from plan's output, while the value is always NULL. private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] = - plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) } + plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) } + + override def conf: SQLConf = SQLConf.get def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p: Union if p.children.forall(isEmptyLocalRelation) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index 89bfcee078fb..45edf266bbce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -46,18 +46,27 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { } plan.transform { - case Except(left, right) if isEligible(left, right) => - Distinct(Filter(Not(transformCondition(left, skipProject(right))), left)) + case e @ Except(left, right) if isEligible(left, right) => + val newCondition = transformCondition(left, skipProject(right)) + newCondition.map { c => + Distinct(Filter(Not(c), left)) + }.getOrElse { + e + } } } - private def transformCondition(left: LogicalPlan, right: LogicalPlan): Expression = { + private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = { val filterCondition = InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap - filterCondition.transform { case a : AttributeReference => attributeNameMap(a.name) } + if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) { + Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) }) + } else { + None + } } // TODO: This can be further extended in the future. 
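To make the guard added in ReplaceExceptWithFilter above concrete: the rule only fires when the right side of the EXCEPT is the left side plus deterministic filters whose references resolve against the left output, so the set difference can be expressed as a negated filter. A hedged DataFrame-level illustration of the plan shape the rule targets (table and column names are made up):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("except-demo").getOrCreate()
    import spark.implicits._

    val sales = Seq((1, 5), (2, 15), (3, 25)).toDF("id", "amount")
    val big = sales.filter($"amount" > 10)

    // The right side is just the left side plus a deterministic filter, so the
    // optimizer can rewrite the Except roughly as
    //   sales.filter(!($"amount" > 10)).distinct()
    // instead of computing a full set difference between the two plans.
    val small = sales.except(big)
    small.show()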
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index 1f20b7661489..2aa762e2595a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -187,11 +187,11 @@ object StarSchemaDetection extends PredicateHelper { stats.rowCount match { case Some(rowCount) if rowCount >= 0 => if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { - val colStats = stats.attributeStats.get(col) - if (colStats.get.nullCount > 0) { + val colStats = stats.attributeStats.get(col).get + if (!colStats.hasCountStats || colStats.nullCount.get > 0) { false } else { - val distinctCount = colStats.get.distinctCount + val distinctCount = colStats.distinctCount.get val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) // ndvMaxErr adjusted based on TPCDS 1TB data results relDiff <= conf.ndvMaxError * 2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7d830bbb7dc3..1c0b7bd80680 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -506,18 +506,21 @@ object NullPropagation extends Rule[LogicalPlan] { /** - * Propagate foldable expressions: * Replace attributes with aliases of the original foldable expressions if possible. - * Other optimizations will take advantage of the propagated foldable expressions. - * + * Other optimizations will take advantage of the propagated foldable expressions. For example, + * this rule can optimize * {{{ * SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3 - * ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() * }}} + * to + * {{{ + * SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() + * }}} + * and other rules can further optimize it and remove the ORDER BY operator. */ object FoldablePropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val foldableMap = AttributeMap(plan.flatMap { + var foldableMap = AttributeMap(plan.flatMap { case Project(projectList, _) => projectList.collect { case a: Alias if a.child.foldable => (a.toAttribute, a) } @@ -530,38 +533,44 @@ object FoldablePropagation extends Rule[LogicalPlan] { if (foldableMap.isEmpty) { plan } else { - var stop = false CleanupAliases(plan.transformUp { - // A leaf node should not stop the folding process (note that we are traversing up the - // tree, starting at the leaf nodes); so we are allowing it. - case l: LeafNode => - l - // We can only propagate foldables for a subset of unary nodes. - case u: UnaryNode if !stop && canPropagateFoldables(u) => + case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) => u.transformExpressions(replaceFoldable) - // Allow inner joins. We do not allow outer join, although its output attributes are - // derived from its children, they are actually different attributes: the output of outer - // join is not always picked from its children, but can also be null. + // Join derives the output attributes from its child while they are actually not the + // same attributes. 
For example, the output of outer join is not always picked from its + // children, but can also be null. We should exclude these miss-derived attributes when + // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, Inner, _) if !stop => - j.transformExpressions(replaceFoldable) - - // We can fold the projections an expand holds. However expand changes the output columns - // and often reuses the underlying attributes; so we cannot assume that a column is still - // foldable after the expand has been applied. - // TODO(hvanhovell): Expand should use new attributes as the output attributes. - case expand: Expand if !stop => - val newExpand = expand.copy(projections = expand.projections.map { projection => + case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => + val newJoin = j.transformExpressions(replaceFoldable) + val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { + case _: InnerLike | LeftExistence(_) => Nil + case LeftOuter => right.output + case RightOuter => left.output + case FullOuter => left.output ++ right.output + }) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => missDerivedAttrsSet.contains(attr) + }.toSeq) + newJoin + + // We can not replace the attributes in `Expand.output`. If there are other non-leaf + // operators that have the `output` field, we should put them here too. + case expand: Expand if foldableMap.nonEmpty => + expand.copy(projections = expand.projections.map { projection => projection.map(_.transform(replaceFoldable)) }) - stop = true - newExpand - case other => - stop = true + // For other plans, they are not safe to apply foldable propagation, and they should not + // propagate foldable expressions from children. + case other if foldableMap.nonEmpty => + val childrenOutputSet = AttributeSet(other.children.flatMap(_.output)) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => childrenOutputSet.contains(attr) + }.toSeq) other }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 2673bea648d0..709db6d8bec7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -369,13 +369,14 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { case ne => (ne.exprId, evalAggOnZeroTups(ne)) }.toMap - case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") + case _ => + sys.error(s"Unexpected operator in scalar subquery: $lp") } val resultMap = evalPlan(plan) // By convention, the scalar subquery result is the leftmost field. 
- resultMap(plan.output.head.exprId) + resultMap.getOrElse(plan.output.head.exprId, None) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bdc357d54a87..64eed2388458 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -503,7 +503,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val join = right.optionalMap(left)(Join(_, _, Inner, None)) withJoinRelations(join, relation) } - ctx.lateralView.asScala.foldLeft(from)(withGenerate) + if (ctx.pivotClause() != null) { + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } } /** @@ -614,6 +618,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging plan } + /** + * Add a [[Pivot]] to a logical plan. + */ + private def withPivot( + ctx: PivotClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val aggregates = Option(ctx.aggregates).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText) + val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply) + Pivot(None, pivotColumn, pivotValues, aggregates, query) + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 9b127f91648e..89347f4b1f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.catalyst.parser +import java.util + import scala.collection.mutable.StringBuilder import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -39,6 +41,13 @@ object ParserUtils { throw new ParseException(s"Operation not allowed: $message", ctx) } + def checkDuplicateClauses[T]( + nodes: util.List[T], clauseName: String, ctx: ParserRuleContext): Unit = { + if (nodes.size() > 1) { + throw new ParseException(s"Found duplicate clauses: $clauseName", ctx) + } + } + /** Check if duplicate keys exist in a set of key-value pairs. 
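// [Review note] For context on the AstBuilder.withPivot addition above, a hedged usage sketch of
// the kind of query the new pivotClause handling is meant to parse. The table `courseSales`, its
// columns, and the local SparkSession are illustrative only and not part of this patch.
import org.apache.spark.sql.SparkSession

object PivotSqlSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("pivot-sketch").getOrCreate()
    import spark.implicits._

    Seq((2012, "dotNET", 10000), (2012, "Java", 20000), (2013, "dotNET", 48000))
      .toDF("year", "course", "earnings")
      .createOrReplaceTempView("courseSales")

    // The PIVOT clause attaches to the FROM clause, which is where visitFromClause now
    // dispatches to withPivot instead of processing lateral views.
    spark.sql(
      """SELECT * FROM courseSales
        |PIVOT (SUM(earnings) FOR course IN ('dotNET', 'Java'))""".stripMargin).show()

    spark.stop()
  }
}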
*/ def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = { keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cc391aae5578..626f90570719 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -199,7 +199,7 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { object PhysicalAggregation { // groupingExpressions, aggregateExpressions, resultExpressions, child type ReturnType = - (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan) + (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) def unapply(a: Any): Option[ReturnType] = a match { case logical.Aggregate(groupingExpressions, resultExpressions, child) => @@ -213,7 +213,10 @@ object PhysicalAggregation { expr.collect { // addExpr() always returns false for non-deterministic expressions and do not add them. case agg: AggregateExpression - if (!equivalentAggregateExpressions.addExpr(agg)) => agg + if !equivalentAggregateExpressions.addExpr(agg) => agg + case udf: PythonUDF + if PythonUDF.isGroupAggPandasUDF(udf) && + !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -241,6 +244,10 @@ object PhysicalAggregation { // so replace each aggregate expression by its corresponding attribute in the set: equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute + // Similar to AggregateExpression + case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => + equivalentAggregateExpressions.getEquivalentExprs(ue).headOption + .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => // Since we're using `namedGroupingAttributes` to extract the grouping key // columns, we need to replace grouping key expressions with their corresponding diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index ddf2cbf2ab91..64cb8c726772 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -103,7 +103,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT var changed = false @inline def transformExpression(e: Expression): Expression = { - val newE = f(e) + val newE = CurrentOrigin.withOrigin(e.origin) { + f(e) + } if (newE.fastEquals(e)) { e } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index d73d7e73f28d..8c4828a4cef2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -43,10 +44,17 @@ object LocalRelation { } } -case class LocalRelation(output: Seq[Attribute], - data: Seq[InternalRow] = Nil, - // Indicates whether this relation has data from a streaming source. - override val isStreaming: Boolean = false) +/** + * Logical plan node for scanning data from a local collection. + * + * @param data The local collection holding the data. It doesn't need to be sent to executors + * and then doesn't need to be serializable. + */ +case class LocalRelation( + output: Seq[Attribute], + data: Seq[InternalRow] = Nil, + // Indicates whether this relation has data from a streaming source. + override val isStreaming: Boolean = false) extends LeafNode with analysis.MultiInstanceRelation { // A local relation must have resolved output. @@ -70,7 +78,7 @@ case class LocalRelation(output: Seq[Attribute], } override def computeStats(): Statistics = - Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) + Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a38458add7b5..e487693927ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -86,6 +86,10 @@ abstract class LogicalPlan } } + private[this] lazy val childAttributes = AttributeSeq(children.flatMap(_.output)) + + private[this] lazy val outputAttributes = AttributeSeq(output) + /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as @@ -94,7 +98,7 @@ abstract class LogicalPlan def resolveChildren( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, children.flatMap(_.output), resolver) + childAttributes.resolve(nameParts, resolver) /** * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this @@ -104,7 +108,7 @@ abstract class LogicalPlan def resolve( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, output, resolver) + outputAttributes.resolve(nameParts, resolver) /** * Given an attribute name, split it to name parts by dot, but @@ -114,111 +118,18 @@ abstract class LogicalPlan def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver) - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * This assumes `name` has multiple parts, where the 1st part is a qualifier - * (i.e. table name, alias, or subquery alias). 
- * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsTableColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - assert(nameParts.length > 1) - if (attribute.qualifier.exists(resolver(_, nameParts.head))) { - // At least one qualifier matches. See if remaining parts match. - val remainingParts = nameParts.tail - resolveAsColumn(remainingParts, resolver, attribute) - } else { - None - } + outputAttributes.resolve(UnresolvedAttribute.parseAttributeName(name), resolver) } /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier. - * See the comment above `candidates` variable in resolve() for semantics the returned data. + * Refreshes (or invalidates) any metadata/data cached in the plan recursively. */ - private def resolveAsColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - if (resolver(attribute.name, nameParts.head)) { - Option((attribute.withName(nameParts.head), nameParts.tail.toList)) - } else { - None - } - } - - /** Performs attribute resolution given a name and a sequence of possible attributes. */ - protected def resolve( - nameParts: Seq[String], - input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { - - // A sequence of possible candidate matches. - // Each candidate is a tuple. The first element is a resolved attribute, followed by a list - // of parts that are to be resolved. - // For example, consider an example where "a" is the table name, "b" is the column name, - // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", - // and the second element will be List("c"). - var candidates: Seq[(Attribute, List[String])] = { - // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (nameParts.length > 1) { - input.flatMap { option => - resolveAsTableColumn(nameParts, resolver, option) - } - } else { - Seq.empty - } - } - - // If none of attributes match `table.column` pattern, we try to resolve it as a column. - if (candidates.isEmpty) { - candidates = input.flatMap { candidate => - resolveAsColumn(nameParts, resolver, candidate) - } - } - - def name = UnresolvedAttribute(nameParts).name - - candidates.distinct match { - // One match, no nested fields, use it. - case Seq((a, Nil)) => Some(a) - - // One match, but we also need to extract the requested nested field. - case Seq((a, nestedFields)) => - // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and aliased it with the last part of the name. - // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final - // expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => - ExtractValue(expr, Literal(fieldName), resolver)) - Some(Alias(fieldExprs, nestedFields.last)()) - - // No matches. - case Seq() => - logTrace(s"Could not find $name in ${input.mkString(", ")}") - None - - // More than one match. 
- case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") - throw new AnalysisException( - s"Reference '$name' is ambiguous, could be: $referenceNames.") - } - } + def refresh(): Unit = children.foreach(_.refresh()) /** - * Refreshes (or invalidates) any metadata/data cached in the plan recursively. + * Returns the output ordering that this plan generates. */ - def refresh(): Unit = children.foreach(_.refresh()) + def outputOrdering: Seq[SortOrder] = Nil } /** @@ -247,12 +158,15 @@ abstract class UnaryNode extends LogicalPlan { protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { + case a @ Alias(l: Literal, _) => + allConstraints += EqualTo(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { case expr: Expression if expr.semanticEquals(e) => a.toAttribute }) + allConstraints += EqualNullSafe(e, a.toAttribute) case _ => // Don't change. } @@ -271,3 +185,7 @@ abstract class BinaryNode extends LogicalPlan { override final def children: Seq[LogicalPlan] = Seq(left, right) } + +abstract class OrderPreservingUnaryNode extends UnaryNode { + override final def outputOrdering: Seq[SortOrder] = child.outputOrdering +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index e0748043c46e..2c248d74869c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical /** - * A visitor pattern for traversing a [[LogicalPlan]] tree and compute some properties. + * A visitor pattern for traversing a [[LogicalPlan]] tree and computing some properties. */ trait LogicalPlanVisitor[T] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index b0f611fd38de..cc352c59dff8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -trait QueryPlanConstraints { self: LogicalPlan => +trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan => /** * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For @@ -32,7 +32,7 @@ trait QueryPlanConstraints { self: LogicalPlan => ExpressionSet( validConstraints .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints, output)) .filter { c => c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic } @@ -51,13 +51,42 @@ trait QueryPlanConstraints { self: LogicalPlan => * See [[Canonicalize]] for more details. 
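// [Review note] A worked sketch of the getAliasedConstraints additions above (constraints are
// modelled as plain strings rather than Catalyst expressions; names are illustrative). For
// SELECT a AS x, 1 AS one FROM t WHERE a > 5, the Project can now expose `one = 1` for the
// literal alias and `a <=> x` tying the alias back to the aliased expression.
object AliasedConstraintsSketch {
  def main(args: Array[String]): Unit = {
    val childConstraints = Set("a > 5")
    // (output name, aliased expression); "1" stands in for a literal child.
    val projectList = Seq(("x", "a"), ("one", "1"))

    var all = childConstraints
    projectList.foreach {
      case (alias, "1") =>
        // New case: a literal alias directly yields an equality constraint on the output.
        all += s"$alias = 1"
      case (alias, expr) =>
        // Existing rewrite: references to the aliased expression become the output attribute...
        all ++= all.map(_.replace(expr, alias))
        // ...plus the new null-safe equality between the expression and its alias.
        all += s"$expr <=> $alias"
    }
    println(all) // e.g. Set(a > 5, x > 5, a <=> x, one = 1)
  }
}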
*/ protected def validConstraints: Set[Expression] = Set.empty +} + +trait ConstraintHelper { + + /** + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5`. + */ + def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferredConstraints = Set.empty[Expression] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + val candidateConstraints = constraints - eq + inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) + inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) + case _ => // No inference + } + inferredConstraints -- constraints + } + + private def replaceConstraints( + constraints: Set[Expression], + source: Expression, + destination: Attribute): Set[Expression] = constraints.map(_ transform { + case e: Expression if e.semanticEquals(source) => destination + }) /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this * returns a constraint of the form `isNotNull(a)` */ - private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + def constructIsNotNullConstraints( + constraints: Set[Expression], + output: Seq[Attribute]): Set[Expression] = { // First, we propagate constraints from the null intolerant expressions. var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) @@ -93,61 +122,4 @@ trait QueryPlanConstraints { self: LogicalPlan => case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] } - - // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so - // we may avoid producing recursive constraints. - private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( - expressions.collect { - case a: Alias => (a.toAttribute, a.child) - } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) - // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. - - /** - * Infers an additional set of constraints from a given set of equality constraints. - * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5`. - */ - private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints) - var inferredConstraints = Set.empty[Expression] - aliasedConstraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = aliasedConstraints - eq - inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) - inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) - case _ => // No inference - } - inferredConstraints -- constraints - } - - /** - * Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints. - * Thus non-converging inference can be prevented. - * E.g. `Alias(b, f(a)), a = b` infers `f(a) = f(f(a))` without eliminating aliased expressions. - * Also, the size of constraints is reduced without losing any information. 
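// [Review note] Standalone sketch of the transitive inference that the new
// ConstraintHelper.inferAdditionalConstraints above performs, with constraints modelled as
// strings: from {a = 5, a = b} it derives {b = 5}. Illustrative only; the real rule works on
// Expression trees and attribute references rather than text.
object InferConstraintsSketch {
  def infer(constraints: Set[String]): Set[String] = {
    // Pick out attribute-to-attribute equalities such as "a = b".
    val equalities = constraints.collect {
      case c if c.matches("[a-z] = [a-z]") => (c(0).toString, c(4).toString)
    }
    val inferred = equalities.flatMap { case (l, r) =>
      val others = constraints - s"$l = $r"
      // Substitute each side of the equality into the remaining constraints.
      others.map(_.replace(l, r)) ++ others.map(_.replace(r, l))
    }
    inferred -- constraints
  }

  def main(args: Array[String]): Unit = {
    println(infer(Set("a = 5", "a = b"))) // Set(b = 5)
  }
}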
- * When the inferred filters are pushed down the operators that generate the alias, - * the alias names used in filters are replaced by the aliased expressions. - */ - private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression]) - : Set[Expression] = { - val attributesInEqualTo = constraints.flatMap { - case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil - case _ => Nil - } - var aliasedConstraints = constraints - attributesInEqualTo.foreach { a => - if (aliasMap.contains(a)) { - val child = aliasMap.get(a).get - aliasedConstraints = replaceConstraints(aliasedConstraints, child, a) - } - } - aliasedConstraints - } - - private def replaceConstraints( - constraints: Set[Expression], - source: Expression, - destination: Attribute): Set[Expression] = constraints.map(_ transform { - case e: Expression if e.semanticEquals(source) => destination - }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 96b199d7f20b..b3a48860aa63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -27,6 +27,7 @@ import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils} @@ -79,11 +80,10 @@ case class Statistics( /** * Statistics collected for a column. * - * 1. Supported data types are defined in `ColumnStat.supportsType`. - * 2. The JVM data type stored in min/max is the internal data type for the corresponding + * 1. The JVM data type stored in min/max is the internal data type for the corresponding * Catalyst data type. For example, the internal type of DateType is Int, and that the internal * type of TimestampType is Long. - * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * 2. There is no guarantee that the statistics collected are accurate. Approximation algorithms * (sketches) might have been used, and the data collected can also be stale. * * @param distinctCount number of distinct values @@ -95,240 +95,32 @@ case class Statistics( * @param histogram histogram of the values */ case class ColumnStat( - distinctCount: BigInt, - min: Option[Any], - max: Option[Any], - nullCount: BigInt, - avgLen: Long, - maxLen: Long, + distinctCount: Option[BigInt] = None, + min: Option[Any] = None, + max: Option[Any] = None, + nullCount: Option[BigInt] = None, + avgLen: Option[Long] = None, + maxLen: Option[Long] = None, histogram: Option[Histogram] = None) { - // We currently don't store min/max for binary/string type. This can change in the future and - // then we need to remove this require. - require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String])) - require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String])) - - /** - * Returns a map from string to string that can be used to serialize the column stats. - * The key is the name of the field (e.g. 
"distinctCount" or "min"), and the value is the string - * representation for the value. min/max values are converted to the external data type. For - * example, for DateType we store java.sql.Date, and for TimestampType we store - * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]]. - * - * As part of the protocol, the returned map always contains a key called "version". - * In the case min/max values are null (None), they won't appear in the map. - */ - def toMap(colName: String, dataType: DataType): Map[String, String] = { - val map = new scala.collection.mutable.HashMap[String, String] - map.put(ColumnStat.KEY_VERSION, "1") - map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) - map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) - map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) - map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) - min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) } - max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) } - histogram.foreach { h => map.put(ColumnStat.KEY_HISTOGRAM, HistogramSerializer.serialize(h)) } - map.toMap - } - - /** - * Converts the given value from Catalyst data type to string representation of external - * data type. - */ - private def toExternalString(v: Any, colName: String, dataType: DataType): String = { - val externalValue = dataType match { - case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) - case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) - case BooleanType | _: IntegralType | FloatType | DoubleType => v - case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal - // This version of Spark does not use min/max for binary/string types so we ignore it. - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column $colName of data type: $dataType.") - } - externalValue.toString - } - -} + // Are distinctCount and nullCount statistics defined? + val hasCountStats = distinctCount.isDefined && nullCount.isDefined + // Are min and max statistics defined? + val hasMinMaxStats = min.isDefined && max.isDefined -object ColumnStat extends Logging { - - // List of string keys used to serialize ColumnStat - val KEY_VERSION = "version" - private val KEY_DISTINCT_COUNT = "distinctCount" - private val KEY_MIN_VALUE = "min" - private val KEY_MAX_VALUE = "max" - private val KEY_NULL_COUNT = "nullCount" - private val KEY_AVG_LEN = "avgLen" - private val KEY_MAX_LEN = "maxLen" - private val KEY_HISTOGRAM = "histogram" - - /** Returns true iff the we support gathering column statistics on column of the given type. */ - def supportsType(dataType: DataType): Boolean = dataType match { - case _: IntegralType => true - case _: DecimalType => true - case DoubleType | FloatType => true - case BooleanType => true - case DateType => true - case TimestampType => true - case BinaryType | StringType => true - case _ => false - } - - /** Returns true iff the we support gathering histogram on column of the given type. */ - def supportsHistogram(dataType: DataType): Boolean = dataType match { - case _: IntegralType => true - case _: DecimalType => true - case DoubleType | FloatType => true - case DateType => true - case TimestampType => true - case _ => false - } - - /** - * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats - * from some external storage. 
The serialization side is defined in [[ColumnStat.toMap]]. - */ - def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = { - try { - Some(ColumnStat( - distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), - // Note that flatMap(Option.apply) turns Option(null) into None. - min = map.get(KEY_MIN_VALUE) - .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), - max = map.get(KEY_MAX_VALUE) - .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), - nullCount = BigInt(map(KEY_NULL_COUNT).toLong), - avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, - maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong, - histogram = map.get(KEY_HISTOGRAM).map(HistogramSerializer.deserialize) - )) - } catch { - case NonFatal(e) => - logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e) - None - } - } - - /** - * Converts from string representation of external data type to the corresponding Catalyst data - * type. - */ - private def fromExternalString(s: String, name: String, dataType: DataType): Any = { - dataType match { - case BooleanType => s.toBoolean - case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) - case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) - case ByteType => s.toByte - case ShortType => s.toShort - case IntegerType => s.toInt - case LongType => s.toLong - case FloatType => s.toFloat - case DoubleType => s.toDouble - case _: DecimalType => Decimal(s) - // This version of Spark does not use min/max for binary/string types so we ignore it. - case BinaryType | StringType => null - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column $name of data type: $dataType.") - } - } - - /** - * Constructs an expression to compute column statistics for a given column. - * - * The expression should create a single struct column with the following schema: - * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, - * distinctCountsForIntervals: Array[Long] - * - * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and - * as a result should stay in sync with it. - */ - def statExprs( - col: Attribute, - conf: SQLConf, - colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { - def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => - expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } - }) - val one = Literal(1, LongType) - - // the approximate ndv (num distinct value) should never be larger than the number of rows - val numNonNulls = if (col.nullable) Count(col) else Count(one) - val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) - val numNulls = Subtract(Count(one), numNonNulls) - val defaultSize = Literal(col.dataType.defaultSize, LongType) - val nullArray = Literal(null, ArrayType(LongType)) - - def fixedLenTypeStruct: CreateNamedStruct = { - val genHistogram = - ColumnStat.supportsHistogram(col.dataType) && colPercentiles.contains(col) - val intervalNdvsExpr = if (genHistogram) { - ApproxCountDistinctForIntervals(col, - Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) - } else { - nullArray - } - // For fixed width types, avg size should be the same as max size. 
- struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, - defaultSize, defaultSize, intervalNdvsExpr) - } - - col.dataType match { - case _: IntegralType => fixedLenTypeStruct - case _: DecimalType => fixedLenTypeStruct - case DoubleType | FloatType => fixedLenTypeStruct - case BooleanType => fixedLenTypeStruct - case DateType => fixedLenTypeStruct - case TimestampType => fixedLenTypeStruct - case BinaryType | StringType => - // For string and binary type, we don't compute min, max or histogram - val nullLit = Literal(null, col.dataType) - struct( - ndv, nullLit, nullLit, numNulls, - // Set avg/max size to default size if all the values are null or there is no value. - Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), - Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), - nullArray) - case _ => - throw new AnalysisException("Analyzing column statistics is not supported for column " + - s"${col.name} of data type: ${col.dataType}.") - } - } - - /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ - def rowToColumnStat( - row: InternalRow, - attr: Attribute, - rowCount: Long, - percentiles: Option[ArrayData]): ColumnStat = { - // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. - val cs = ColumnStat( - distinctCount = BigInt(row.getLong(0)), - // for string/binary min/max, get should return null - min = Option(row.get(1, attr.dataType)), - max = Option(row.get(2, attr.dataType)), - nullCount = BigInt(row.getLong(3)), - avgLen = row.getLong(4), - maxLen = row.getLong(5) - ) - if (row.isNullAt(6)) { - cs - } else { - val ndvs = row.getArray(6).toLongArray() - assert(percentiles.get.numElements() == ndvs.length + 1) - val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) - // Construct equi-height histogram - val bins = ndvs.zipWithIndex.map { case (ndv, i) => - HistogramBin(endpoints(i), endpoints(i + 1), ndv) - } - val nonNullRows = rowCount - cs.nullCount - val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) - cs.copy(histogram = Some(histogram)) - } - } + // Are avgLen and maxLen statistics defined? + val hasLenStats = avgLen.isDefined && maxLen.isDefined + def toCatalogColumnStat(colName: String, dataType: DataType): CatalogColumnStat = + CatalogColumnStat( + distinctCount = distinctCount, + min = min.map(CatalogColumnStat.toExternalString(_, colName, dataType)), + max = max.map(CatalogColumnStat.toExternalString(_, colName, dataType)), + nullCount = nullCount, + avgLen = avgLen, + maxLen = maxLen, + histogram = histogram) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 95e099c340af..3bf32ef7884e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -43,11 +43,12 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { * This node is inserted at the top of a subquery when it is optimized. This makes sure we can * recognize a subquery as such, and it allows us to write subquery aware transformations. 
*/ -case class Subquery(child: LogicalPlan) extends UnaryNode { +case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output } -case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { +case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) + extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows @@ -125,7 +126,7 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) - extends UnaryNode with PredicateHelper { + extends OrderPreservingUnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows @@ -469,6 +470,7 @@ case class Sort( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows + override def outputOrdering: Seq[SortOrder] = order } /** Factory for constructing new `Range` nodes. */ @@ -522,6 +524,15 @@ case class Range( override def computeStats(): Statistics = { Statistics(sizeInBytes = LongType.defaultSize * numElements) } + + override def outputOrdering: Seq[SortOrder] = { + val order = if (step > 0) { + Ascending + } else { + Descending + } + output.map(a => SortOrder(a, order)) + } } case class Aggregate( @@ -675,17 +686,34 @@ case class GroupingSets( override lazy val resolved: Boolean = false } +/** + * A constructor for creating a pivot, which will later be converted to a [[Project]] + * or an [[Aggregate]] during the query analysis. + * + * @param groupByExprsOpt A sequence of group by expressions. This field should be None if coming + * from SQL, in which group by expressions are not explicitly specified. + * @param pivotColumn The pivot column. + * @param pivotValues A sequence of values for the pivot column. + * @param aggregates The aggregation expressions, each with or without an alias. + * @param child Child operator + */ case class Pivot( - groupByExprs: Seq[NamedExpression], + groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, pivotValues: Seq[Literal], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { - case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) - case _ => pivotValues.flatMap{ value => - aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + override lazy val resolved = false // Pivot will be replaced after being resolved. + override def output: Seq[Attribute] = { + val pivotAgg = aggregates match { + case agg :: Nil => + pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => + pivotValues.flatMap { value => + aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + } } + groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg } } @@ -728,7 +756,7 @@ object Limit { * * See [[Limit]] for more information. 
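// [Review note] A simplified model of the new outputOrdering plumbing above (plain Scala, not
// Catalyst nodes; names are illustrative): order-preserving unary nodes forward their child's
// ordering, Sort introduces one, and Range derives its direction from the sign of `step`.
object OutputOrderingSketch {
  sealed trait Direction
  case object Ascending extends Direction
  case object Descending extends Direction
  case class SortOrder(column: String, direction: Direction)

  sealed trait Plan { def outputOrdering: Seq[SortOrder] = Nil }
  case class Sort(order: Seq[SortOrder], child: Plan) extends Plan {
    override def outputOrdering: Seq[SortOrder] = order
  }
  case class Range(column: String, step: Long) extends Plan {
    override def outputOrdering: Seq[SortOrder] =
      Seq(SortOrder(column, if (step > 0) Ascending else Descending))
  }
  // Stand-in for Project/Filter/Limit/SubqueryAlias, which only pass rows through.
  case class OrderPreserving(child: Plan) extends Plan {
    override def outputOrdering: Seq[SortOrder] = child.outputOrdering
  }

  def main(args: Array[String]): Unit = {
    val plan = OrderPreserving(Sort(Seq(SortOrder("a", Ascending)), Range("id", step = -1)))
    println(plan.outputOrdering) // List(SortOrder(a,Ascending)): the sort order survives the wrapper.
  }
}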
*/ -case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = { limitExpr match { @@ -744,7 +772,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN * * See [[Limit]] for more information. */ -case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRowsPerPartition: Option[Long] = { @@ -764,7 +792,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case class SubqueryAlias( alias: String, child: LogicalPlan) - extends UnaryNode { + extends OrderPreservingUnaryNode { override def doCanonicalize(): LogicalPlan = child.canonicalized @@ -903,6 +931,7 @@ case class Deduplicate( * This analysis barrier will be removed at the end of analysis stage. */ case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { + override protected def innerChildren: Seq[LogicalPlan] = Seq(child) override def output: Seq[Attribute] = child.output override def isStreaming: Boolean = child.isStreaming override def doCanonicalize(): LogicalPlan = child.canonicalized diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index c41fac4015ec..111c594a53e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -32,13 +32,15 @@ object AggregateEstimation { val childStats = agg.child.stats // Check if we have column stats for all group-by columns. val colStatsExist = agg.groupingExpressions.forall { e => - e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute]) + e.isInstanceOf[Attribute] && + childStats.attributeStats.get(e.asInstanceOf[Attribute]).exists(_.hasCountStats) } if (rowCountsExist(agg.child) && colStatsExist) { // Multiply distinct counts of group-by columns. This is an upper bound, which assumes // the data contains all combinations of distinct values of group-by columns. 
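// [Review note] A worked instance of the group-by cardinality bound described in the comment
// above and computed by the foldLeft just below (the numbers are made up). With the Option-based
// column stats, the rule only fires when every grouping column has a distinct count.
object AggregateEstimationSketch {
  def main(args: Array[String]): Unit = {
    val childRowCount = BigInt(1000)
    val groupByNdv: Seq[Option[BigInt]] = Seq(Some(BigInt(10)), Some(BigInt(20)))

    if (groupByNdv.forall(_.isDefined)) {
      val combinations = groupByNdv.flatten.product // 10 * 20 = 200 possible key combinations
      val outputRows = combinations.min(childRowCount) // an aggregate cannot output more rows than its input
      println(outputRows) // 200
    }
  }
}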
var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( - (res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount) + (res, expr) => res * + childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount.get) outputRows = if (agg.groupingExpressions.isEmpty) { // If there's no group-by columns, the output is a single row containing values of aggregate diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index ca0775a2e840..b6c16079d198 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import org.apache.spark.sql.catalyst.plans.logical._ /** - * An [[LogicalPlanVisitor]] that computes a the statistics used in a cost-based optimizer. + * A [[LogicalPlanVisitor]] that computes the statistics for the cost-based optimizer. */ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index d793f77413d1..211a2a071737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.{DecimalType, _} - object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ @@ -38,9 +37,18 @@ object EstimationUtils { } } + /** Check if each attribute has column stat containing distinct and null counts + * in the corresponding statistic. */ + def columnStatsWithCountsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = { + statsAndAttr.forall { case (stats, attr) => + stats.attributeStats.get(attr).map(_.hasCountStats).getOrElse(false) + } + } + + /** Statistics for a Column containing only NULLs. */ def nullColumnStat(dataType: DataType, rowCount: BigInt): ColumnStat = { - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = rowCount, - avgLen = dataType.defaultSize, maxLen = dataType.defaultSize) + ColumnStat(distinctCount = Some(0), min = None, max = None, nullCount = Some(rowCount), + avgLen = Some(dataType.defaultSize), maxLen = Some(dataType.defaultSize)) } /** @@ -63,29 +71,33 @@ object EstimationUtils { AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _))) } - def getOutputSize( + def getSizePerRow( attributes: Seq[Attribute], - outputRowCount: BigInt, attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. 
- val sizePerRow = 8 + attributes.map { attr => - if (attrStats.contains(attr)) { + 8 + attributes.map { attr => + if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) { attr.dataType match { case StringType => // UTF8String: base + offset + numBytes - attrStats(attr).avgLen + 8 + 4 + attrStats(attr).avgLen.get + 8 + 4 case _ => - attrStats(attr).avgLen + attrStats(attr).avgLen.get } } else { attr.dataType.defaultSize } }.sum + } + def getOutputSize( + attributes: Seq[Attribute], + outputRowCount: BigInt, + attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // Output size can't be zero, or sizeInBytes of BinaryNode will also be zero // (simple computation of statistics returns product of children). - if (outputRowCount > 0) outputRowCount * sizePerRow else 1 + if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1 } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 4cc32de2d32d..5a3eeefaedb1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -225,7 +225,7 @@ case class FilterEstimation(plan: Filter) extends Logging { attr: Attribute, isNull: Boolean, update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attr)) { + if (!colStatsMap.contains(attr) || !colStatsMap(attr).hasCountStats) { logDebug("[CBO] No statistics for " + attr) return None } @@ -234,14 +234,14 @@ case class FilterEstimation(plan: Filter) extends Logging { val nullPercent: Double = if (rowCountValue == 0) { 0 } else { - (BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)).toDouble + (BigDecimal(colStat.nullCount.get) / BigDecimal(rowCountValue)).toDouble } if (update) { val newStats = if (isNull) { - colStat.copy(distinctCount = 0, min = None, max = None) + colStat.copy(distinctCount = Some(0), min = None, max = None) } else { - colStat.copy(nullCount = 0) + colStat.copy(nullCount = Some(0)) } colStatsMap.update(attr, newStats) } @@ -322,17 +322,21 @@ case class FilterEstimation(plan: Filter) extends Logging { // value. 
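// [Review note] Standalone sketch of the per-row size formula that getSizePerRow above now
// factors out of getOutputSize; the avgLen values are made up. Strings carry the UTF8String
// overhead of base + offset + numBytes, attributes without an avgLen fall back to the type's
// default size, and 8 bytes of generic row overhead are always added.
object SizePerRowSketch {
  case class Attr(name: String, isString: Boolean, defaultSize: Int, avgLen: Option[Long])

  def sizePerRow(attrs: Seq[Attr]): Long =
    8L + attrs.map { a =>
      a.avgLen match {
        case Some(len) if a.isString => len + 8 + 4 // UTF8String: base + offset + numBytes
        case Some(len) => len
        case None => a.defaultSize.toLong // no column stats for this attribute
      }
    }.sum

  def main(args: Array[String]): Unit = {
    val attrs = Seq(
      Attr("name", isString = true, defaultSize = 20, avgLen = Some(16L)),
      Attr("age", isString = false, defaultSize = 4, avgLen = Some(4L)),
      Attr("note", isString = true, defaultSize = 20, avgLen = None))
    println(sizePerRow(attrs)) // 8 + (16 + 8 + 4) + 4 + 20 = 60
  }
}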
val newStats = attr.dataType match { case StringType | BinaryType => - colStat.copy(distinctCount = 1, nullCount = 0) + colStat.copy(distinctCount = Some(1), nullCount = Some(0)) case _ => - colStat.copy(distinctCount = 1, min = Some(literal.value), - max = Some(literal.value), nullCount = 0) + colStat.copy(distinctCount = Some(1), min = Some(literal.value), + max = Some(literal.value), nullCount = Some(0)) } colStatsMap.update(attr, newStats) } if (colStat.histogram.isEmpty) { - // returns 1/ndv if there is no histogram - Some(1.0 / colStat.distinctCount.toDouble) + if (!colStat.distinctCount.isEmpty) { + // returns 1/ndv if there is no histogram + Some(1.0 / colStat.distinctCount.get.toDouble) + } else { + None + } } else { Some(computeEqualityPossibilityByHistogram(literal, colStat)) } @@ -378,19 +382,23 @@ case class FilterEstimation(plan: Filter) extends Logging { attr: Attribute, hSet: Set[Any], update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attr)) { + if (!colStatsMap.hasDistinctCount(attr)) { logDebug("[CBO] No statistics for " + attr) return None } val colStat = colStatsMap(attr) - val ndv = colStat.distinctCount + val ndv = colStat.distinctCount.get val dataType = attr.dataType var newNdv = ndv // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => + if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { + return Some(0.0) + } + val statsInterval = ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval] val validQuerySet = hSet.filter { v => @@ -407,16 +415,20 @@ case class FilterEstimation(plan: Filter) extends Logging { // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = ndv.min(BigInt(validQuerySet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin), - max = Some(newMax), nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(newNdv), min = Some(newMin), + max = Some(newMax), nullCount = Some(0)) colStatsMap.update(attr, newStats) } // We assume the whole set since there is no min/max information for String/Binary type case StringType | BinaryType => + if (ndv.toDouble == 0) { + return Some(0.0) + } + newNdv = ndv.min(BigInt(hSet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0)) colStatsMap.update(attr, newStats) } } @@ -443,12 +455,17 @@ case class FilterEstimation(plan: Filter) extends Logging { literal: Literal, update: Boolean): Option[Double] = { + if (!colStatsMap.hasMinMaxStats(attr) || !colStatsMap.hasDistinctCount(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + val colStat = colStatsMap(attr) val statsInterval = ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval] val max = statsInterval.max val min = statsInterval.min - val ndv = colStat.distinctCount.toDouble + val ndv = colStat.distinctCount.get.toDouble // determine the overlapping degree between predicate interval and column's interval val numericLiteral = EstimationUtils.toDouble(literal.value, literal.dataType) @@ -520,8 +537,8 @@ case class FilterEstimation(plan: Filter) extends Logging { newMax = newValue } - val newStats = colStat.copy(distinctCount = ceil(ndv * percent), - min = newMin, max = newMax, nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(ceil(ndv * percent)), + min = newMin, max 
= newMax, nullCount = Some(0)) colStatsMap.update(attr, newStats) } @@ -637,11 +654,11 @@ case class FilterEstimation(plan: Filter) extends Logging { attrRight: Attribute, update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attrLeft)) { + if (!colStatsMap.hasCountStats(attrLeft)) { logDebug("[CBO] No statistics for " + attrLeft) return None } - if (!colStatsMap.contains(attrRight)) { + if (!colStatsMap.hasCountStats(attrRight)) { logDebug("[CBO] No statistics for " + attrRight) return None } @@ -668,7 +685,7 @@ case class FilterEstimation(plan: Filter) extends Logging { val minRight = statsIntervalRight.min // determine the overlapping degree between predicate interval and column's interval - val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) + val allNotNull = (colStatLeft.nullCount.get == 0) && (colStatRight.nullCount.get == 0) val (noOverlap: Boolean, completeOverlap: Boolean) = op match { // Left < Right or Left <= Right // - no overlap: @@ -707,14 +724,14 @@ case class FilterEstimation(plan: Filter) extends Logging { case _: EqualTo => ((maxLeft < minRight) || (maxRight < minLeft), (minLeft == minRight) && (maxLeft == maxRight) && allNotNull - && (colStatLeft.distinctCount == colStatRight.distinctCount) + && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get) ) case _: EqualNullSafe => // For null-safe equality, we use a very restrictive condition to evaluate its overlap. // If null values exists, we set it to partial overlap. (((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull, (minLeft == minRight) && (maxLeft == maxRight) && allNotNull - && (colStatLeft.distinctCount == colStatRight.distinctCount) + && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get) ) } @@ -731,9 +748,9 @@ case class FilterEstimation(plan: Filter) extends Logging { if (update) { // Need to adjust new min/max after the filter condition is applied - val ndvLeft = BigDecimal(colStatLeft.distinctCount) + val ndvLeft = BigDecimal(colStatLeft.distinctCount.get) val newNdvLeft = ceil(ndvLeft * percent) - val ndvRight = BigDecimal(colStatRight.distinctCount) + val ndvRight = BigDecimal(colStatRight.distinctCount.get) val newNdvRight = ceil(ndvRight * percent) var newMaxLeft = colStatLeft.max @@ -817,10 +834,10 @@ case class FilterEstimation(plan: Filter) extends Logging { } } - val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft, + val newStatsLeft = colStatLeft.copy(distinctCount = Some(newNdvLeft), min = newMinLeft, max = newMaxLeft) colStatsMap(attrLeft) = newStatsLeft - val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight, + val newStatsRight = colStatRight.copy(distinctCount = Some(newNdvRight), min = newMinRight, max = newMaxRight) colStatsMap(attrRight) = newStatsRight } @@ -849,17 +866,35 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { def contains(a: Attribute): Boolean = updatedMap.contains(a.exprId) || originalMap.contains(a) /** - * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in - * originalMap, because updatedMap has the latest (updated) column stats. + * Gets an Option of column stat for the given attribute. + * Prefer the column stat in updatedMap than that in originalMap, + * because updatedMap has the latest (updated) column stats. 
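// [Review note] Simplified model of the ColumnStatsMap behaviour documented above and defined
// just below (plain Scala keyed by name instead of ExprId): lookups prefer the updated stats over
// the original ones, and the has* guards let FilterEstimation skip columns whose Option-valued
// stats are missing instead of calling .get on None.
object ColumnStatsMapSketch {
  case class ColStat(distinctCount: Option[BigInt], nullCount: Option[BigInt]) {
    def hasCountStats: Boolean = distinctCount.isDefined && nullCount.isDefined
  }

  class ColumnStatsMap(original: Map[String, ColStat]) {
    private val updated = scala.collection.mutable.Map.empty[String, ColStat]
    def update(attr: String, stat: ColStat): Unit = updated(attr) = stat
    // Prefer updated stats: they already reflect the predicates applied so far.
    def get(attr: String): Option[ColStat] = updated.get(attr).orElse(original.get(attr))
    def hasCountStats(attr: String): Boolean = get(attr).exists(_.hasCountStats)
  }

  def main(args: Array[String]): Unit = {
    val m = new ColumnStatsMap(Map("a" -> ColStat(Some(100), Some(0)), "b" -> ColStat(None, None)))
    m.update("a", ColStat(Some(1), Some(0))) // e.g. after an `a = 5` predicate was applied
    println(m.get("a")) // Some(ColStat(Some(1),Some(0)))
    println(m.hasCountStats("b")) // false: the estimator falls back instead of failing
  }
}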
*/ - def apply(a: Attribute): ColumnStat = { + def get(a: Attribute): Option[ColumnStat] = { if (updatedMap.contains(a.exprId)) { - updatedMap(a.exprId)._2 + updatedMap.get(a.exprId).map(_._2) } else { - originalMap(a) + originalMap.get(a) } } + def hasCountStats(a: Attribute): Boolean = + get(a).map(_.hasCountStats).getOrElse(false) + + def hasDistinctCount(a: Attribute): Boolean = + get(a).map(_.distinctCount.isDefined).getOrElse(false) + + def hasMinMaxStats(a: Attribute): Boolean = + get(a).map(_.hasCountStats).getOrElse(false) + + /** + * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in + * originalMap, because updatedMap has the latest (updated) column stats. + */ + def apply(a: Attribute): ColumnStat = { + get(a).get + } + /** Updates column stats in updatedMap. */ def update(a: Attribute, stats: ColumnStat): Unit = updatedMap.update(a.exprId, a -> stats) @@ -871,11 +906,14 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { : AttributeMap[ColumnStat] = { val newColumnStats = originalMap.map { case (attr, oriColStat) => val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat) - val newNdv = if (colStat.distinctCount > 1) { + val newNdv = if (colStat.distinctCount.isEmpty) { + // No NDV in the original stats. + None + } else if (colStat.distinctCount.get > 1) { // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows // decreases; otherwise keep it unchanged. - EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, - newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount) + Some(EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, + newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount.get)) } else { // no need to scale down since it is already down to 1 (for skewed distribution case) colStat.distinctCount diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index f0294a424670..2543e38a92c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -85,7 +85,8 @@ case class JoinEstimation(join: Join) extends Logging { // 3. 
Update statistics based on the output of join val inputAttrStats = AttributeMap( leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq) - val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a)) + val attributesWithStat = join.output.filter(a => + inputAttrStats.get(a).map(_.hasCountStats).getOrElse(false)) val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_)) val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { @@ -106,10 +107,10 @@ case class JoinEstimation(join: Join) extends Logging { case FullOuter => fromLeft.map { a => val oriColStat = inputAttrStats(a) - (a, oriColStat.copy(nullCount = oriColStat.nullCount + rightRows)) + (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + rightRows))) } ++ fromRight.map { a => val oriColStat = inputAttrStats(a) - (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) + (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + leftRows))) } case _ => assert(joinType == Inner || joinType == Cross) @@ -219,19 +220,27 @@ case class JoinEstimation(join: Join) extends Logging { private def computeByNdv( leftKey: AttributeReference, rightKey: AttributeReference, - newMin: Option[Any], - newMax: Option[Any]): (BigInt, ColumnStat) = { + min: Option[Any], + max: Option[Any]): (BigInt, ColumnStat) = { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) - val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) + val maxNdv = leftKeyStat.distinctCount.get.max(rightKeyStat.distinctCount.get) // Compute cardinality by the basic formula. val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv) // Get the intersected column stat. 
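// [Review note] A worked example of the equi-join cardinality formula used by computeByNdv above
// (the row counts and ndv values are made up): rows(L) * rows(R) / max(ndv(leftKey), ndv(rightKey)),
// with the joined key's new distinct count taken as the smaller of the two sides.
object JoinCardinalitySketch {
  def main(args: Array[String]): Unit = {
    val leftRows = BigInt(1000)
    val rightRows = BigInt(500)
    val leftNdv = BigInt(100)
    val rightNdv = BigInt(50)

    val maxNdv = leftNdv.max(rightNdv)
    val card = (BigDecimal(leftRows) * BigDecimal(rightRows) / BigDecimal(maxNdv))
      .setScale(0, BigDecimal.RoundingMode.CEILING)
    val newNdv = leftNdv.min(rightNdv)
    println(s"estimated rows = $card, joined key ndv = $newNdv") // estimated rows = 5000, joined key ndv = 50
  }
}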
- val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) - val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) - val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 - val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + val newNdv = Some(leftKeyStat.distinctCount.get.min(rightKeyStat.distinctCount.get)) + val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) { + Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get)) + } else { + None + } + val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) { + Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2) + } else { + None + } + val newStats = ColumnStat(newNdv, min, max, Some(0), newAvgLen, newMaxLen) (ceil(card), newStats) } @@ -267,9 +276,17 @@ case class JoinEstimation(join: Join) extends Logging { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) - val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) - val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 - val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen) + val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) { + Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get)) + } else { + None + } + val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) { + Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2) + } else { + None + } + val newStats = ColumnStat(Some(ceil(totalNdv)), newMin, newMax, Some(0), newAvgLen, newMaxLen) (ceil(card), newStats) } @@ -292,10 +309,14 @@ case class JoinEstimation(join: Join) extends Logging { } else { val oldColStat = oldAttrStats(a) val oldNdv = oldColStat.distinctCount - val newNdv = if (join.left.outputSet.contains(a)) { - updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv) + val newNdv = if (oldNdv.isDefined) { + Some(if (join.left.outputSet.contains(a)) { + updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv.get) + } else { + updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv.get) + }) } else { - updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv) + None } val newColStat = oldColStat.copy(distinctCount = newNdv) // TODO: support nullCount updates for specific outer joins @@ -313,7 +334,7 @@ case class JoinEstimation(join: Join) extends Logging { // Note: join keys from EqualNullSafe also fall into this case (Coalesce), consider to // support it in the future by using `nullCount` in column stats. 
case (lk: AttributeReference, rk: AttributeReference) - if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) + if columnStatsWithCountsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 5e1c4e0bd606..ee43f9126386 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { private def visitUnaryNode(p: UnaryNode): Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. - val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8 - val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8 + val childRowSize = EstimationUtils.getSizePerRow(p.child.output) + val outputRowSize = EstimationUtils.getSizePerRow(p.output) // Assume there will be the same number of rows as child has. var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize if (sizeInBytes == 0) { @@ -48,8 +48,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { } /** - * For leaf nodes, use its computeStats. For other nodes, we assume the size in bytes is the - * sum of all of the children's. + * For leaf nodes, use its `computeStats`. For other nodes, we assume the size in bytes is the + * product of all of the children's `computeStats`. */ override def default(p: LogicalPlan): Statistics = p match { case p: LeafNode => p.computeStats() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index e57c842ce2a3..4d9a9925fe3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -30,18 +30,43 @@ import org.apache.spark.sql.types.{DataType, IntegerType} * - Intra-partition ordering of data: In this case the distribution describes guarantees made * about how tuples are distributed within a single partition. */ -sealed trait Distribution +sealed trait Distribution { + /** + * The required number of partitions for this distribution. If it's None, then any number of + * partitions is allowed for this distribution. + */ + def requiredNumPartitions: Option[Int] + + /** + * Creates a default partitioning for this distribution, which can satisfy this distribution while + * matching the given number of partitions. + */ + def createPartitioning(numPartitions: Int): Partitioning +} /** * Represents a distribution where no promises are made about co-location of data. 
*/ -case object UnspecifiedDistribution extends Distribution +case object UnspecifiedDistribution extends Distribution { + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + throw new IllegalStateException("UnspecifiedDistribution does not have default partitioning.") + } +} /** * Represents a distribution that only has a single partition and all tuples of the dataset * are co-located. */ -case object AllTuples extends Distribution +case object AllTuples extends Distribution { + override def requiredNumPartitions: Option[Int] = Some(1) + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(numPartitions == 1, "The default partitioning of AllTuples can only have 1 partition.") + SinglePartition + } +} /** * Represents data where tuples that share the same values for the `clustering` @@ -51,12 +76,41 @@ case object AllTuples extends Distribution */ case class ClusteredDistribution( clustering: Seq[Expression], - numPartitions: Option[Int] = None) extends Distribution { + requiredNumPartitions: Option[Int] = None) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") + HashPartitioning(clustering, numPartitions) + } +} + +/** + * Represents data where tuples have been clustered according to the hash of the given + * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only + * [[HashPartitioning]] can satisfy this distribution. + * + * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the + * number of partitions, this distribution strictly requires which partition the tuple should be in. + */ +case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { + require( + expressions != Nil, + "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") + + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + HashPartitioning(expressions, numPartitions) + } } /** @@ -73,48 +127,33 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - // TODO: This is not really valid... - def clustering: Set[Expression] = ordering.map(_.child).toSet + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + RangePartitioning(ordering, numPartitions) + } } /** * Represents data where tuples are broadcasted to every node. It is quite common that the * entire set of tuples is transformed into different data structure. 
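As a rough illustration of the new Distribution API introduced in this hunk (a minimal sketch added for clarity, assuming the spark-catalyst classes from this patch; it is not part of the diff itself):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType

val key = AttributeReference("key", IntegerType)()

// A required distribution can now build a concrete default partitioning
// for a given number of partitions.
AllTuples.createPartitioning(1)                            // SinglePartition
ClusteredDistribution(Seq(key)).createPartitioning(8)      // HashPartitioning(Seq(key), 8)
HashClusteredDistribution(Seq(key)).createPartitioning(8)  // HashPartitioning(Seq(key), 8)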
*/ -case class BroadcastDistribution(mode: BroadcastMode) extends Distribution +case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { + override def requiredNumPartitions: Option[Int] = Some(1) + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(numPartitions == 1, + "The default partitioning of BroadcastDistribution can only have 1 partition.") + BroadcastPartitioning(mode) + } +} /** - * Describes how an operator's output is split across partitions. The `compatibleWith`, - * `guarantees`, and `satisfies` methods describe relationships between child partitionings, - * target partitionings, and [[Distribution]]s. These relations are described more precisely in - * their individual method docs, but at a high level: - * - * - `satisfies` is a relationship between partitionings and distributions. - * - `compatibleWith` is relationships between an operator's child output partitionings. - * - `guarantees` is a relationship between a child's existing output partitioning and a target - * output partitioning. - * - * Diagrammatically: - * - * +--------------+ - * | Distribution | - * +--------------+ - * ^ - * | - * satisfies - * | - * +--------------+ +--------------+ - * | Child | | Target | - * +----| Partitioning |----guarantees--->| Partitioning | - * | +--------------+ +--------------+ - * | ^ - * | | - * | compatibleWith - * | | - * +------------+ - * + * Describes how an operator's output is split across partitions. It has 2 major properties: + * 1. number of partitions. + * 2. if it can satisfy a given distribution. */ -sealed trait Partitioning { +trait Partitioning { /** Returns the number of partitions that the data is split across */ val numPartitions: Int @@ -123,113 +162,35 @@ sealed trait Partitioning { * to satisfy the partitioning scheme mandated by the `required` [[Distribution]], * i.e. the current dataset does not need to be re-partitioned for the `required` * Distribution (it is possible that tuples within a partition need to be reorganized). - */ - def satisfies(required: Distribution): Boolean - - /** - * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] - * guarantees the same partitioning scheme described by `other`. - * - * Compatibility of partitionings is only checked for operators that have multiple children - * and that require a specific child output [[Distribution]], such as joins. - * - * Intuitively, partitionings are compatible if they route the same partitioning key to the same - * partition. For instance, two hash partitionings are only compatible if they produce the same - * number of output partitionings and hash records according to the same hash function and - * same partitioning key schema. - * - * Put another way, two partitionings are compatible with each other if they satisfy all of the - * same distribution guarantees. - */ - def compatibleWith(other: Partitioning): Boolean - - /** - * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees - * the same partitioning scheme described by `other`. If a `A.guarantees(B)`, then repartitioning - * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance - * optimization to allow the exchange planner to avoid redundant repartitionings. By default, - * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number - * of partitions, same strategy (range or hash), etc). 
- * - * In order to enable more aggressive optimization, this strict equality check can be relaxed. - * For example, say that the planner needs to repartition all of an operator's children so that - * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children - * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens - * to be hash-partitioned with a single partition then we do not need to re-shuffle this child; - * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees` - * [[SinglePartition]]. - * - * The SinglePartition example given above is not particularly interesting; guarantees' real - * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion - * of null-safe partitionings, under which partitionings can specify whether rows whose - * partitioning keys contain null values will be grouped into the same partition or whether they - * will have an unknown / random distribution. If a partitioning does not require nulls to be - * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered - * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot - * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a - * symmetric relation. * - * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows - * produced by `A` could have also been produced by `B`. + * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if + * the [[Partitioning]] only have one partition. Implementations can overwrite this method with + * special logic. */ - def guarantees(other: Partitioning): Boolean = this == other -} - -object Partitioning { - def allCompatible(partitionings: Seq[Partitioning]): Boolean = { - // Note: this assumes transitivity - partitionings.sliding(2).map { - case Seq(a) => true - case Seq(a, b) => - if (a.numPartitions != b.numPartitions) { - assert(!a.compatibleWith(b) && !b.compatibleWith(a)) - false - } else { - a.compatibleWith(b) && b.compatibleWith(a) - } - }.forall(_ == true) - } -} - -case class UnknownPartitioning(numPartitions: Int) extends Partitioning { - override def satisfies(required: Distribution): Boolean = required match { + def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true + case AllTuples => numPartitions == 1 case _ => false } - - override def compatibleWith(other: Partitioning): Boolean = false - - override def guarantees(other: Partitioning): Boolean = false } +case class UnknownPartitioning(numPartitions: Int) extends Partitioning + /** * Represents a partitioning where rows are distributed evenly across output partitions * by starting from a random target partition number and distributing rows in a round-robin * fashion. This partitioning is used when implementing the DataFrame.repartition() operator. 
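To make the simplified `satisfies` contract described above concrete, a small hedged sketch (editor's illustration under the same assumptions as the previous sketch; not part of the diff):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val hashed = HashPartitioning(Seq(a), numPartitions = 8)

hashed.satisfies(UnspecifiedDistribution)                 // true, by the default rule
hashed.satisfies(AllTuples)                               // false, it has 8 partitions
hashed.satisfies(ClusteredDistribution(Seq(a)))           // true, clustering keys match
hashed.satisfies(ClusteredDistribution(Seq(a), Some(4)))  // false, partition count is pinned to 4
hashed.satisfies(HashClusteredDistribution(Seq(a)))       // true, identical expressions in order
SinglePartition.satisfies(AllTuples)                      // true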
*/ -case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning { - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = false - - override def guarantees(other: Partitioning): Boolean = false -} +case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 override def satisfies(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1) + case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } - - override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1 - - override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1 } /** @@ -244,22 +205,19 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering, desiredPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) - case _ => false + override def satisfies(required: Distribution): Boolean = { + super.satisfies(required) || { + required match { + case h: HashClusteredDistribution => + expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { + case (l, r) => l.semanticEquals(r) + } + case ClusteredDistribution(requiredClustering, requiredNumPartitions) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && + (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case _ => false + } + } } /** @@ -288,25 +246,18 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case OrderedDistribution(requiredOrdering) => - val minSize = Seq(requiredOrdering.size, ordering.size).min - requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, desiredPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this.semanticEquals(o) - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this.semanticEquals(o) - case _ => false + override def satisfies(required: Distribution): Boolean = { + super.satisfies(required) || { + required match 
{ + case OrderedDistribution(requiredOrdering) => + val minSize = Seq(requiredOrdering.size, ordering.size).min + requiredOrdering.take(minSize) == ordering.take(minSize) + case ClusteredDistribution(requiredClustering, requiredNumPartitions) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && + (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case _ => false + } + } } } @@ -347,20 +298,6 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) override def satisfies(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) - /** - * Returns true if any `partitioning` of this collection is compatible with - * the given [[Partitioning]]. - */ - override def compatibleWith(other: Partitioning): Boolean = - partitionings.exists(_.compatibleWith(other)) - - /** - * Returns true if any `partitioning` of this collection guarantees - * the given [[Partitioning]]. - */ - override def guarantees(other: Partitioning): Boolean = - partitionings.exists(_.guarantees(other)) - override def toString: String = { partitionings.map(_.toString).mkString("(", " or ", ")") } @@ -377,9 +314,4 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { case BroadcastDistribution(m) if m == mode => true case _ => false } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning(m) if m == mode => true - case _ => false - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala new file mode 100644 index 000000000000..62f7541150a6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.rules + +import scala.collection.JavaConverters._ + +import com.google.common.util.concurrent.AtomicLongMap + +case class QueryExecutionMetering() { + private val timeMap = AtomicLongMap.create[String]() + private val numRunsMap = AtomicLongMap.create[String]() + private val numEffectiveRunsMap = AtomicLongMap.create[String]() + private val timeEffectiveRunsMap = AtomicLongMap.create[String]() + + /** Resets statistics about time spent running specific rules */ + def resetMetrics(): Unit = { + timeMap.clear() + numRunsMap.clear() + numEffectiveRunsMap.clear() + timeEffectiveRunsMap.clear() + } + + def totalTime: Long = { + timeMap.sum() + } + + def totalNumRuns: Long = { + numRunsMap.sum() + } + + def incExecutionTimeBy(ruleName: String, delta: Long): Unit = { + timeMap.addAndGet(ruleName, delta) + } + + def incTimeEffectiveExecutionBy(ruleName: String, delta: Long): Unit = { + timeEffectiveRunsMap.addAndGet(ruleName, delta) + } + + def incNumEffectiveExecution(ruleName: String): Unit = { + numEffectiveRunsMap.incrementAndGet(ruleName) + } + + def incNumExecution(ruleName: String): Unit = { + numRunsMap.incrementAndGet(ruleName) + } + + /** Dump statistics about time spent running specific rules. */ + def dumpTimeSpent(): String = { + val map = timeMap.asMap().asScala + val maxLengthRuleNames = map.keys.map(_.toString.length).max + + val colRuleName = "Rule".padTo(maxLengthRuleNames, " ").mkString + val colRunTime = "Effective Time / Total Time".padTo(len = 47, " ").mkString + val colNumRuns = "Effective Runs / Total Runs".padTo(len = 47, " ").mkString + + val ruleMetrics = map.toSeq.sortBy(_._2).reverseMap { case (name, time) => + val timeEffectiveRun = timeEffectiveRunsMap.get(name) + val numRuns = numRunsMap.get(name) + val numEffectiveRun = numEffectiveRunsMap.get(name) + + val ruleName = name.padTo(maxLengthRuleNames, " ").mkString + val runtimeValue = s"$timeEffectiveRun / $time".padTo(len = 47, " ").mkString + val numRunValue = s"$numEffectiveRun / $numRuns".padTo(len = 47, " ").mkString + s"$ruleName $runtimeValue $numRunValue" + }.mkString("\n", "\n", "") + + s""" + |=== Metrics of Analyzer/Optimizer Rules === + |Total number of runs: $totalNumRuns + |Total time: ${totalTime / 1000000000D} seconds + | + |$colRuleName $colRunTime $colNumRuns + |$ruleMetrics + """.stripMargin + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 7e4b784033bf..dccb44ddebfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.catalyst.rules -import scala.collection.JavaConverters._ - -import com.google.common.util.concurrent.AtomicLongMap - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode @@ -28,18 +24,16 @@ import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.util.Utils object RuleExecutor { - protected val timeMap = AtomicLongMap.create[String]() - - /** Resets statistics about time spent running specific rules */ - def resetTime(): Unit = timeMap.clear() + protected val queryExecutionMeter = QueryExecutionMetering() /** Dump statistics about time spent running specific rules. 
*/ def dumpTimeSpent(): String = { - val map = timeMap.asMap().asScala - val maxSize = map.keys.map(_.toString.length).max - map.toSeq.sortBy(_._2).reverseMap { case (k, v) => - s"${k.padTo(maxSize, " ").mkString} $v" - }.mkString("\n", "\n", "") + queryExecutionMeter.dumpTimeSpent() + } + + /** Resets statistics about time spent running specific rules */ + def resetMetrics(): Unit = { + queryExecutionMeter.resetMetrics() } } @@ -77,6 +71,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { */ def execute(plan: TreeType): TreeType = { var curPlan = plan + val queryExecutionMetrics = RuleExecutor.queryExecutionMeter batches.foreach { batch => val batchStartPlan = curPlan @@ -91,15 +86,18 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val startTime = System.nanoTime() val result = rule(plan) val runTime = System.nanoTime() - startTime - RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime) if (!result.fastEquals(plan)) { + queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) + queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) logTrace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} """.stripMargin) } + queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime) + queryExecutionMetrics.incNumExecution(rule.ruleName) // Run the structural integrity checker against the plan after each rule. if (!isPlanIntegral(result)) { @@ -135,9 +133,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (!batchStartPlan.fastEquals(curPlan)) { logDebug( s""" - |=== Result of Batch ${batch.name} === - |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} - """.stripMargin) + |=== Result of Batch ${batch.name} === + |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} + """.stripMargin) } else { logTrace(s"Batch ${batch.name} has no effect.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 9beef41d639f..104b42861484 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.util import scala.reflect.ClassTag +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ object ArrayData { def toArrayData(input: Any): ArrayData = input match { @@ -42,6 +43,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def array: Array[Any] + def toSeq[T](dataType: DataType): IndexedSeq[T] = + new ArrayDataIndexedSeq[T](this, dataType) + def setNullAt(i: Int): Unit def update(i: Int, value: Any): Unit @@ -137,30 +141,55 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def toArray[T: ClassTag](elementType: DataType): Array[T] = { val size = numElements() + val accessor = InternalRow.getAccessor(elementType) val values = new Array[T](size) var i = 0 while (i < size) { if (isNullAt(i)) { values(i) = null.asInstanceOf[T] } else { - values(i) = get(i, elementType).asInstanceOf[T] + values(i) = accessor(this, i).asInstanceOf[T] } i += 1 } values } - // todo: specialize this. 
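The RuleExecutor changes above reduce to a simple pattern: record total time and run counts for every rule invocation, and attribute "effective" time and runs only when the rule actually changed the plan. A standalone, hypothetical sketch of that pattern (the names below are the editor's, not the patch's):

import com.google.common.util.concurrent.AtomicLongMap

object RuleMeter {
  private val time = AtomicLongMap.create[String]()
  private val runs = AtomicLongMap.create[String]()
  private val effectiveTime = AtomicLongMap.create[String]()
  private val effectiveRuns = AtomicLongMap.create[String]()

  def timed[T](ruleName: String, plan: T)(rule: T => T): T = {
    val start = System.nanoTime()
    val result = rule(plan)
    val elapsed = System.nanoTime() - start
    if (result != plan) {  // stands in for !result.fastEquals(plan)
      effectiveRuns.incrementAndGet(ruleName)
      effectiveTime.addAndGet(ruleName, elapsed)
    }
    time.addAndGet(ruleName, elapsed)
    runs.incrementAndGet(ruleName)
    result
  }
}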
def foreach(elementType: DataType, f: (Int, Any) => Unit): Unit = { val size = numElements() + val accessor = InternalRow.getAccessor(elementType) var i = 0 while (i < size) { if (isNullAt(i)) { f(i, null) } else { - f(i, get(i, elementType)) + f(i, accessor(this, i)) } i += 1 } } } + +/** + * Implements an `IndexedSeq` interface for `ArrayData`. Notice that if the original `ArrayData` + * is a primitive array and contains null elements, it is better to ask for `IndexedSeq[Any]`, + * instead of `IndexedSeq[Int]`, in order to keep the null elements. + */ +class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends IndexedSeq[T] { + + private val accessor: (SpecializedGetters, Int) => Any = InternalRow.getAccessor(dataType) + + override def apply(idx: Int): T = + if (0 <= idx && idx < arrayData.numElements()) { + if (arrayData.isNullAt(idx)) { + null.asInstanceOf[T] + } else { + accessor(arrayData, idx).asInstanceOf[T] + } + } else { + throw new IndexOutOfBoundsException( + s"Index $idx must be between 0 and the length of the ArrayData.") + } + + override def length: Int = arrayData.numElements() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index fa69b8af62c8..e646da0659e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -296,10 +296,28 @@ object DateTimeUtils { * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` */ def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = { - stringToTimestamp(s, defaultTimeZone()) + stringToTimestamp(s, defaultTimeZone(), rejectTzInString = false) } def stringToTimestamp(s: UTF8String, timeZone: TimeZone): Option[SQLTimestamp] = { + stringToTimestamp(s, timeZone, rejectTzInString = false) + } + + /** + * Converts a timestamp string to microseconds from the unix epoch, w.r.t. the given timezone. + * Returns None if the input string is not a valid timestamp format. + * + * @param s the input timestamp string. + * @param timeZone the timezone of the timestamp string, will be ignored if the timestamp string + * already contains timezone information and `forceTimezone` is false. + * @param rejectTzInString if true, rejects timezone in the input string, i.e., if the + * timestamp string contains timezone, like `2000-10-10 00:00:00+00:00`, + * return None. + */ + def stringToTimestamp( + s: UTF8String, + timeZone: TimeZone, + rejectTzInString: Boolean): Option[SQLTimestamp] = { if (s == null) { return None } @@ -417,6 +435,8 @@ object DateTimeUtils { return None } + if (tz.isDefined && rejectTzInString) return None + val c = if (tz.isEmpty) { Calendar.getInstance(timeZone) } else { @@ -870,24 +890,14 @@ object DateTimeUtils { * If time1 and time2 having the same day of month, or both are the last day of month, * it returns an integer (time under a day will be ignored). * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. + * Otherwise, the difference is calculated based on 31 days per month. + * If `roundOff` is set to true, the result is rounded to 8 decimal places. */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = { - monthsBetween(time1, time2, defaultTimeZone()) - } - - /** - * Returns number of months between time1 and time2. 
time1 and time2 are expressed in - * microseconds since 1.1.1970. - * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). - * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. - */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp, timeZone: TimeZone): Double = { + def monthsBetween( + time1: SQLTimestamp, + time2: SQLTimestamp, + roundOff: Boolean, + timeZone: TimeZone): Double = { val millis1 = time1 / 1000L val millis2 = time2 / 1000L val date1 = millisToDays(millis1, timeZone) @@ -898,16 +908,25 @@ object DateTimeUtils { val months1 = year1 * 12 + monthInYear1 val months2 = year2 * 12 + monthInYear2 + val monthDiff = (months1 - months2).toDouble + if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) { - return (months1 - months2).toDouble + return monthDiff + } + // using milliseconds can cause precision loss with more than 8 digits + // we follow Hive's implementation which uses seconds + val secondsInDay1 = (millis1 - daysToMillis(date1, timeZone)) / 1000L + val secondsInDay2 = (millis2 - daysToMillis(date2, timeZone)) / 1000L + val secondsDiff = (dayInMonth1 - dayInMonth2) * SECONDS_PER_DAY + secondsInDay1 - secondsInDay2 + // 2678400D is the number of seconds in 31 days + // every month is considered to be 31 days long in this function + val diff = monthDiff + secondsDiff / 2678400D + if (roundOff) { + // rounding to 8 digits + math.round(diff * 1e8) / 1e8 + } else { + diff } - // milliseconds is enough for 8 digits precision on the right side - val timeInDay1 = millis1 - daysToMillis(date1, timeZone) - val timeInDay2 = millis2 - daysToMillis(date2, timeZone) - val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY - val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 - // rounding to 8 digits - math.round(diff * 1e8) / 1e8 } // Thursday = 0 since 1970/Jan/01 => Thursday diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index eb7941cf9e6a..3190e511e2cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -40,12 +40,14 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats * See the G-K article for more details. 
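A worked example of the seconds-based monthsBetween formula introduced above (editor's arithmetic on assumed inputs, both timestamps at midnight in the same time zone):

// months_between('2018-03-16 00:00:00', '2018-02-15 00:00:00')
// monthDiff   = 1                     (March vs. February of the same year)
// secondsDiff = (16 - 15) * 86400     = 86400
// diff        = 1 + 86400 / 2678400   = 1.0322580645...
// with roundOff = true the result is rounded to 8 digits: 1.03225806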
* @param count the count of all the elements *inserted in the sampled buffer* * (excluding the head buffer) + * @param compressed whether the statistics have been compressed */ class QuantileSummaries( val compressThreshold: Int, val relativeError: Double, val sampled: Array[Stats] = Array.empty, - val count: Long = 0L) extends Serializable { + val count: Long = 0L, + var compressed: Boolean = false) extends Serializable { // a buffer of latest samples seen so far private val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty @@ -60,6 +62,7 @@ class QuantileSummaries( */ def insert(x: Double): QuantileSummaries = { headSampled += x + compressed = false if (headSampled.size >= defaultHeadSize) { val result = this.withHeadBufferInserted if (result.sampled.length >= compressThreshold) { @@ -105,7 +108,7 @@ class QuantileSummaries( if (newSamples.isEmpty || (sampleIdx == sampled.length && opsIdx == sorted.length - 1)) { 0 } else { - math.floor(2 * relativeError * currentCount).toInt + math.floor(2 * relativeError * currentCount).toLong } val tuple = Stats(currentSample, 1, delta) @@ -135,11 +138,11 @@ class QuantileSummaries( assert(inserted.count == count + headSampled.size) val compressed = compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count) - new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count) + new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count, true) } private def shallowCopy: QuantileSummaries = { - new QuantileSummaries(compressThreshold, relativeError, sampled, count) + new QuantileSummaries(compressThreshold, relativeError, sampled, count, compressed) } /** @@ -163,7 +166,7 @@ class QuantileSummaries( val res = (sampled ++ other.sampled).sortBy(_.value) val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count) new QuantileSummaries( - other.compressThreshold, other.relativeError, comp, other.count + count) + other.compressThreshold, other.relativeError, comp, other.count + count, true) } } @@ -192,10 +195,10 @@ class QuantileSummaries( } // Target rank - val rank = math.ceil(quantile * count).toInt + val rank = math.ceil(quantile * count).toLong val targetError = relativeError * count // Minimum rank at current sample - var minRank = 0 + var minRank = 0L var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) @@ -235,7 +238,7 @@ object QuantileSummaries { * @param g the minimum rank jump from the previous value's minimum rank * @param delta the maximum span of the rank. */ - case class Stats(value: Double, g: Int, delta: Int) + case class Stats(value: Double, g: Long, delta: Long) private def compressImmut( currentSamples: IndexedSeq[Stats], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala new file mode 100644 index 000000000000..4fe07a071c1c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.util.UUID + +import org.apache.commons.math3.random.MersenneTwister + +import org.apache.spark.unsafe.types.UTF8String + +/** + * This class is used to generate a UUID from Pseudo-Random Numbers. + * + * For the algorithm, see RFC 4122: A Universally Unique IDentifier (UUID) URN Namespace, + * section 4.4 "Algorithms for Creating a UUID from Truly Random or Pseudo-Random Numbers". + */ +case class RandomUUIDGenerator(randomSeed: Long) { + private val random = new MersenneTwister(randomSeed) + + def getNextUUID(): UUID = { + val mostSigBits = (random.nextLong() & 0xFFFFFFFFFFFF0FFFL) | 0x0000000000004000L + val leastSigBits = (random.nextLong() | 0x8000000000000000L) & 0xBFFFFFFFFFFFFFFFL + + new UUID(mostSigBits, leastSigBits) + } + + def getNextUUIDUTF8String(): UTF8String = UTF8String.fromString(getNextUUID().toString()) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4f77c54a7af5..895e15075656 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,11 +27,13 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -70,7 +72,7 @@ object SQLConf { * Default config. Only used when there is no active SparkSession for the thread. * See [[get]] for more information. */ - private val fallbackConf = new ThreadLocal[SQLConf] { + private lazy val fallbackConf = new ThreadLocal[SQLConf] { override def initialValue: SQLConf = new SQLConf } @@ -121,14 +123,12 @@ object SQLConf { .createWithDefault(10) val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed") - .internal() .doc("When set to true Spark SQL will automatically select a compression codec for each " + "column based on statistics of the data.") .booleanConf .createWithDefault(true) val COLUMN_BATCH_SIZE = buildConf("spark.sql.inMemoryColumnarStorage.batchSize") - .internal() .doc("Controls the size of batches for columnar caching. 
Larger batch sizes can improve " + "memory utilization and compression, but risk OOMs when caching data.") .intConf @@ -141,6 +141,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val CACHE_VECTORIZED_READER_ENABLED = + buildConf("spark.sql.inMemoryColumnarStorage.enableVectorizedReader") + .doc("Enables vectorized reader for columnar caching.") + .booleanConf + .createWithDefault(true) + val COLUMN_VECTOR_OFFHEAP_ENABLED = buildConf("spark.sql.columnVector.offheap.enabled") .internal() @@ -247,7 +253,7 @@ object SQLConf { val CONSTRAINT_PROPAGATION_ENABLED = buildConf("spark.sql.constraintPropagation.enabled") .internal() .doc("When true, the query optimizer will infer and propagate data constraints in the query " + - "plan to optimize them. Constraint propagation can sometimes be computationally expensive" + + "plan to optimize them. Constraint propagation can sometimes be computationally expensive " + "for certain kinds of query plans (such as those with a large number of predicates and " + "aliases) which might negatively impact overall runtime.") .booleanConf @@ -261,6 +267,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val FILE_COMRESSION_FACTOR = buildConf("spark.sql.sources.fileCompressionFactor") + .internal() + .doc("When estimating the output data size of a table scan, multiply the file size with this " + + "factor as the estimated data size, in case the data is compressed in the file and lead to" + + " a heavily underestimated result.") + .doubleConf + .checkValue(_ > 0, "the value of fileDataSizeFactor must be larger than 0") + .createWithDefault(1.0) + val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") .doc("When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -323,11 +338,14 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + - "uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or " + + "`parquet.compression` is specified in the table-specific options/properties, the " + + "precedence would be `compression`, `parquet.compression`, " + + "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " + + "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") @@ -335,6 +353,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_DATE_ENABLED = buildConf("spark.sql.parquet.filterPushdown.date") + .doc("If true, enables Parquet filter push-down optimization for Date. 
" + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -363,9 +388,17 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.parquet.columnarReaderBatchSize") + .doc("The number of rows to include in a parquet vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .intConf + .createWithDefault(4096) + val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec use when writing ORC files. Acceptable values include: " + - "none, uncompressed, snappy, zlib, lzo.") + .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + + "`orc.compress` is specified in the table-specific options/properties, the precedence " + + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo")) @@ -373,20 +406,39 @@ object SQLConf { val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") .doc("When native, use the native version of ORC support instead of the ORC library in Hive " + - "1.2.1. It is 'hive' by default prior to Spark 2.3.") + "1.2.1. It is 'hive' by default prior to Spark 2.4.") .internal() .stringConf .checkValues(Set("hive", "native")) .createWithDefault("native") + val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") + .doc("Enables vectorized orc decoding.") + .booleanConf + .createWithDefault(true) + + val ORC_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.orc.columnarReaderBatchSize") + .doc("The number of rows to include in a orc vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .intConf + .createWithDefault(4096) + + val ORC_COPY_BATCH_TO_SPARK = buildConf("spark.sql.orc.copyBatchToSpark") + .doc("Whether or not to copy the ORC columnar batch to Spark columnar batch in the " + + "vectorized ORC reader.") + .internal() + .booleanConf + .createWithDefault(false) + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + - "when reading data stored in HDFS.") + "when reading data stored in HDFS. 
This configuration will be deprecated in the future " + + "releases and replaced by spark.files.ignoreMissingFiles.") .booleanConf .createWithDefault(false) @@ -435,6 +487,16 @@ object SQLConf { .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString) + val TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP = + buildConf("spark.sql.typeCoercion.compareDateTimestampInTimestamp") + .internal() + .doc("When true (default), compare Date with Timestamp after converting both sides to " + + "Timestamp. This behavior is compatible with Hive 2.2 or later. See HIVE-15236. " + + "When false, restore the behavior prior to Spark 2.4. Compare Date with Timestamp after " + + "converting both sides to string. This config will be removed in spark 3.0") + .booleanConf + .createWithDefault(true) + val OPTIMIZER_METADATA_ONLY = buildConf("spark.sql.optimizer.metadataOnly") .doc("When true, enable the metadata-only query optimization that use the table's metadata " + "to produce the partition columns instead of table scans. It applies when all the columns " + @@ -449,6 +511,14 @@ object SQLConf { .stringConf .createWithDefault("_corrupt_record") + val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema") + .internal() + .doc("When true, force the output schema of the from_json() function to be nullable " + + "(including all the fields). Otherwise, the schema might not be compatible with" + + "actual data, which leads to curruptions.") + .booleanConf + .createWithDefault(true) + val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout") .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") .timeConf(TimeUnit.SECONDS) @@ -601,6 +671,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME = + buildConf("spark.sql.codegen.useIdInClassName") + .internal() + .doc("When true, embed the (whole-stage) codegen stage ID into " + + "the class name of the generated class as a suffix") + .booleanConf + .createWithDefault(true) + val WHOLESTAGE_MAX_NUM_FIELDS = buildConf("spark.sql.codegen.maxFields") .internal() .doc("The maximum number of fields (including nested fields) that will be supported before" + @@ -626,12 +704,22 @@ object SQLConf { val WHOLESTAGE_HUGE_METHOD_LIMIT = buildConf("spark.sql.codegen.hugeMethodLimit") .internal() .doc("The maximum bytecode size of a single compiled Java function generated by whole-stage " + - "codegen. When the compiled function exceeds this threshold, " + - "the whole-stage codegen is deactivated for this subtree of the current query plan. " + - s"The default value is ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} and " + - "this is a limit in the OpenJDK JVM implementation.") + "codegen. When the compiled function exceeds this threshold, the whole-stage codegen is " + + "deactivated for this subtree of the current query plan. The default value is 65535, which " + + "is the largest bytecode size possible for a valid Java method. 
When running on HotSpot, " + + s"it may be preferable to set the value to ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} " + + "to match HotSpot's implementation.") .intConf - .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT) + .createWithDefault(65535) + + val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR = + buildConf("spark.sql.codegen.splitConsumeFuncByOperator") + .internal() + .doc("When true, whole stage codegen would put the logic of consuming rows of each " + + "physical operator into individual methods, instead of a single big method. This can be " + + "used to avoid oversized function that can miss the opportunity of JIT optimization.") + .booleanConf + .createWithDefault(true) val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") @@ -831,6 +919,14 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10000L) + val STREAMING_NO_DATA_MICRO_BATCHES_ENABLED = + buildConf("spark.sql.streaming.noDataMicroBatchesEnabled") + .doc( + "Whether streaming micro-batch engine will execute batches without data " + + "for eager state management for stateful streaming queries.") + .booleanConf + .createWithDefault(true) + val STREAMING_METRICS_ENABLED = buildConf("spark.sql.streaming.metricsEnabled") .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") @@ -843,6 +939,13 @@ object SQLConf { .intConf .createWithDefault(100) + val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS = + buildConf("spark.sql.streaming.checkpointFileManagerClass") + .doc("The class used to write checkpoint files atomically. This class must be a subclass " + + "of the interface CheckpointFileManager.") + .internal() + .stringConf + val NDV_MAX_ERROR = buildConf("spark.sql.statistics.ndv.maxError") .internal() @@ -866,7 +969,7 @@ object SQLConf { .internal() .doc("The number of bins when generating histograms.") .intConf - .checkValue(num => num > 1, "The number of bins must be large than 1.") + .checkValue(num => num > 1, "The number of bins must be larger than 1.") .createWithDefault(254) val PERCENTILE_ACCURACY = @@ -996,19 +1099,25 @@ object SQLConf { .intConf .createWithDefault(100) - val ARROW_EXECUTION_ENABLE = + val ARROW_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.enabled") - .internal() - .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + - "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + - "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + - "LongType, ShortType") + .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " + + "for use with pyspark.sql.DataFrame.toPandas, and " + + "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. 
" + + "The following data types are unsupported: " + + "BinaryType, MapType, ArrayType of TimestampType, and nested StructType.") .booleanConf .createWithDefault(false) + val ARROW_FALLBACK_ENABLED = + buildConf("spark.sql.execution.arrow.fallback.enabled") + .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " + + "fallback automatically to non-optimized implementations if an error occurs.") + .booleanConf + .createWithDefault(true) + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") - .internal() .doc("When using Apache Arrow, limit the maximum number of records that can be written " + "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") .intConf @@ -1036,6 +1145,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = + buildConf("spark.sql.decimalOperations.allowPrecisionLoss") + .internal() + .doc("When true (default), establishing the result type of an arithmetic operation " + + "happens according to Hive behavior and SQL ANSI 2011 specification, ie. rounding the " + + "decimal part of the result if an exact representation is not possible. Otherwise, NULL " + + "is returned in those cases, as previously.") + .booleanConf + .createWithDefault(true) + val SQL_STRING_REDACTION_PATTERN = ConfigBuilder("spark.sql.redaction.string.regex") .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + @@ -1050,6 +1169,20 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString") + .doc("When this option is set to false and all inputs are binary, `elt` returns " + + "an output as binary. Otherwise, it returns as a string. ") + .booleanConf + .createWithDefault(false) + + val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION = + buildConf("spark.sql.allowCreatingManagedTableUsingNonemptyLocation") + .internal() + .doc("When this option is set to true, creating managed tables with nonempty location " + + "is allowed. Otherwise, an analysis exception is thrown. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1066,6 +1199,60 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") + .internal() + .doc("A comma-separated list of fully qualified data source register class names for which" + + " StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.") + .stringConf + .createWithDefault("") + + val DISABLED_V2_STREAMING_MICROBATCH_READERS = + buildConf("spark.sql.streaming.disabledV2MicroBatchReaders") + .internal() + .doc( + "A comma-separated list of fully qualified data source register class names for which " + + "MicroBatchReadSupport is disabled. Reads from these sources will fall back to the " + + "V1 Sources.") + .stringConf + .createWithDefault("") + + val REJECT_TIMEZONE_IN_STRING = buildConf("spark.sql.function.rejectTimezoneInString") + .internal() + .doc("If true, `to_utc_timestamp` and `from_utc_timestamp` return null if the input string " + + "contains a timezone part, e.g. 
`2000-10-10 00:00:00+00:00`.") + .booleanConf + .createWithDefault(true) + + object PartitionOverwriteMode extends Enumeration { + val STATIC, DYNAMIC = Value + } + + val PARTITION_OVERWRITE_MODE = + buildConf("spark.sql.sources.partitionOverwriteMode") + .doc("When INSERT OVERWRITE a partitioned data source table, we currently support 2 modes: " + + "static and dynamic. In static mode, Spark deletes all the partitions that match the " + + "partition specification(e.g. PARTITION(a=1,b)) in the INSERT statement, before " + + "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " + + "those partitions that have data written into it at runtime. By default we use static " + + "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " + + "affect Hive serde tables, as they are always overwritten with dynamic mode.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(PartitionOverwriteMode.values.map(_.toString)) + .createWithDefault(PartitionOverwriteMode.STATIC.toString) + + val SORT_BEFORE_REPARTITION = + buildConf("spark.sql.execution.sortBeforeRepartition") + .internal() + .doc("When perform a repartition following a shuffle, the output row ordering would be " + + "nondeterministic. If some downstream stages fail and some tasks of the repartition " + + "stage retry, these tasks may generate different data, and that can lead to correctness " + + "issues. Turn on this config to insert a local sort before actually doing repartition " + + "to generate consistent repartition results. The performance of repartition() may go " + + "down since we insert extra local sort before it.") + .booleanConf + .createWithDefault(true) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1087,6 +1274,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. 
*/ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) @@ -1128,6 +1321,9 @@ class SQLConf extends Serializable with Logging { def streamingNoDataProgressEventInterval: Long = getConf(STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL) + def streamingNoDataMicroBatchesEnabled: Boolean = + getConf(STREAMING_NO_DATA_MICRO_BATCHES_ENABLED) + def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED) def streamingProgressRetention: Int = getConf(STREAMING_PROGRESS_RETENTION) @@ -1146,12 +1342,20 @@ class SQLConf extends Serializable with Logging { def orcCompressionCodec: String = getConf(ORC_COMPRESSION) + def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED) + + def orcVectorizedReaderBatchSize: Int = getConf(ORC_VECTORIZED_READER_BATCH_SIZE) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) + def parquetVectorizedReaderBatchSize: Int = getConf(PARQUET_VECTORIZED_READER_BATCH_SIZE) + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) + def cacheVectorizedReaderEnabled: Boolean = getConf(CACHE_VECTORIZED_READER_ENABLED) + def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) def targetPostShuffleInputSize: Long = @@ -1166,6 +1370,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) + def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) @@ -1179,12 +1385,17 @@ class SQLConf extends Serializable with Logging { def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value = HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE)) + def compareDateTimestampInTimestamp : Boolean = + getConf(TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP) + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY) def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def wholeStageUseIdInClassName: Boolean = getConf(WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME) + def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) @@ -1193,6 +1404,9 @@ class SQLConf extends Serializable with Logging { def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) + def wholeStageSplitConsumeFuncByOperator: Boolean = + getConf(WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR) + def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) @@ -1204,8 +1418,12 @@ class SQLConf extends Serializable with Logging { def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) + def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) + def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) + def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. 
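As a usage illustration of the new spark.sql.sources.partitionOverwriteMode option added above (editor's example; the SparkSession `spark` and the table `t`, assumed to be partitioned by `p` with a matching schema, are hypothetical):

// In dynamic mode, only the partitions that receive new rows are overwritten;
// the remaining partitions of t are left untouched.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
spark.range(10)
  .selectExpr("id", "id % 2 AS p")
  .write
  .mode("overwrite")
  .insertInto("t")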
@@ -1371,7 +1589,9 @@ class SQLConf extends Serializable with Logging { def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION) - def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) + def arrowEnabled: Boolean = getConf(ARROW_EXECUTION_ENABLED) + + def arrowFallbackEnabled: Boolean = getConf(ARROW_FALLBACK_ENABLED) def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) @@ -1379,13 +1599,28 @@ class SQLConf extends Serializable with Logging { def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) + def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) def continuousStreamingExecutorPollIntervalMs: Long = getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) + def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) + + def disabledV2StreamingMicroBatchReaders: String = + getConf(DISABLED_V2_STREAMING_MICROBATCH_READERS) + def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + + def allowCreatingManagedTableUsingNonemptyLocation: Boolean = + getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION) + + def partitionOverwriteMode: PartitionOverwriteMode.Value = + PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index d6e0df12218a..0bef11659fc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -295,25 +295,31 @@ object DataType { } /** - * Returns true if the two data types share the same "shape", i.e. the types (including - * nullability) are the same, but the field names don't need to be the same. + * Returns true if the two data types share the same "shape", i.e. the types + * are the same, but the field names don't need to be the same. 
+ * + * @param ignoreNullability whether to ignore nullability when comparing the types */ - def equalsStructurally(from: DataType, to: DataType): Boolean = { + def equalsStructurally( + from: DataType, + to: DataType, + ignoreNullability: Boolean = false): Boolean = { (from, to) match { case (left: ArrayType, right: ArrayType) => equalsStructurally(left.elementType, right.elementType) && - left.containsNull == right.containsNull + (ignoreNullability || left.containsNull == right.containsNull) case (left: MapType, right: MapType) => equalsStructurally(left.keyType, right.keyType) && equalsStructurally(left.valueType, right.valueType) && - left.valueContainsNull == right.valueContainsNull + (ignoreNullability || left.valueContainsNull == right.valueContainsNull) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && fromFields.zip(toFields) .forall { case (l, r) => - equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable + equalsStructurally(l.dataType, r.dataType) && + (ignoreNullability || l.nullable == r.nullable) } case (fromDataType, toDataType) => fromDataType == toDataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 6e050c18b8ac..ef3b67c0d48d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} /** @@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType { val MAX_SCALE = 38 val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) val USER_DEFAULT: DecimalType = DecimalType(10, 0) + val MINIMUM_ADJUSTED_SCALE = 6 // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) @@ -136,10 +137,52 @@ object DecimalType extends AbstractDataType { case DoubleType => DoubleDecimal } + private[sql] def fromLiteral(literal: Literal): DecimalType = literal.value match { + case v: Short => fromBigDecimal(BigDecimal(v)) + case v: Int => fromBigDecimal(BigDecimal(v)) + case v: Long => fromBigDecimal(BigDecimal(v)) + case _ => forType(literal.dataType) + } + + private[sql] def fromBigDecimal(d: BigDecimal): DecimalType = { + DecimalType(Math.max(d.precision, d.scale), d.scale) + } + private[sql] def bounded(precision: Int, scale: Int): DecimalType = { DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) } + /** + * Scale adjustment implementation is based on Hive's one, which is itself inspired to + * SQLServer's one. In particular, when a result precision is greater than + * {@link #MAX_PRECISION}, the corresponding scale is reduced to prevent the integral part of a + * result from being truncated. + * + * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. + */ + private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { + // Assumptions: + assert(precision >= scale) + assert(scale >= 0) + + if (precision <= MAX_PRECISION) { + // Adjustment only needed when we exceed max precision + DecimalType(precision, scale) + } else { + // Precision/scale exceed maximum precision. 
Result must be adjusted to MAX_PRECISION. + val intDigits = precision - scale + // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise + // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits + val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) + // The resulting scale is the maximum between what is available without causing a loss of + // digits for the integer part of the decimal and the minimum guaranteed scale, which is + // computed above + val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) + + DecimalType(MAX_PRECISION, adjustedScale) + } + } + override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT override private[sql] def acceptsType(other: DataType): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e3b0969283a8..362676b25212 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -104,6 +104,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) + /** + * Returns all field names in an array. This is an alias of `fieldNames`. + * + * @since 2.4.0 + */ + def names: Array[String] = fieldNames + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap @@ -264,7 +271,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ def apply(name: String): StructField = { nameToField.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + throw new IllegalArgumentException( + s"""Field "$name" does not exist. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) } /** @@ -277,7 +286,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru val nonExistFields = names -- fieldNamesSet if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( - s"Field ${nonExistFields.mkString(",")} does not exist.") + s"""Nonexistent field(s): ${nonExistFields.mkString(", ")}. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin) } // Preserve the original order of fields. StructType(fields.filter(f => names.contains(f.name))) @@ -290,7 +300,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ def fieldIndex(name: String): Int = { nameToIndex.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + throw new IllegalArgumentException( + s"""Field "$name" does not exist. 
+ |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) } private[sql] def getFieldIndex(name: String): Option[Int] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 5a944e763e09..6af16e2dba10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -97,6 +97,16 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa override def catalogString: String = sqlType.simpleString } +private[spark] object UserDefinedType { + /** + * Get the sqlType of a (potential) [[UserDefinedType]]. + */ + def sqlType(dt: DataType): DataType = dt match { + case udt: UserDefinedType[_] => udt.sqlType + case _ => dt + } +} + /** * The user defined type in Python. * diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index b67c6f3e6e85..76930f936851 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; import org.junit.Assert; import org.junit.Test; @@ -53,7 +54,7 @@ public void testKnownStringAndIntInputs() { for (int i = 0; i < inputs.length; i++) { UTF8String s = UTF8String.fromString("val_" + inputs[i]); - int hash = HiveHasher.hashUnsafeBytes(s.getBaseObject(), s.getBaseOffset(), s.numBytes()); + int hash = HiveHasher.hashUnsafeBytesBlock(s.getMemoryBlock()); Assert.assertEquals(expected[i], ((31 * inputs[i]) + hash)); } } @@ -89,13 +90,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + HiveHasher.hashUnsafeBytesBlock(mb), + HiveHasher.hashUnsafeBytesBlock(mb)); - hashcodes.add(HiveHasher.hashUnsafeBytes( - bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); } // A very loose bound. @@ -112,13 +113,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + HiveHasher.hashUnsafeBytesBlock(mb), + HiveHasher.hashUnsafeBytesBlock(mb)); - hashcodes.add(HiveHasher.hashUnsafeBytes( - paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); } // A very loose bound. 
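A minimal standalone sketch (not Spark's actual code path) of the scale adjustment added in DecimalType.scala above. The "naive" intermediate result types below are illustrative assumptions, but the adjusted outputs line up with the updated DecimalPrecisionSuite expectations later in this patch.

object AdjustPrecisionScaleSketch {
  val MaxPrecision = 38
  val MinimumAdjustedScale = 6

  // Mirrors adjustPrecisionScale: keep all integer digits and preserve at least
  // MinimumAdjustedScale fractional digits when precision overflows.
  def adjust(precision: Int, scale: Int): (Int, Int) =
    if (precision <= MaxPrecision) {
      (precision, scale)
    } else {
      val intDigits = precision - scale
      val minScale = math.min(scale, MinimumAdjustedScale)
      (MaxPrecision, math.max(MaxPrecision - intDigits, minScale))
    }

  def main(args: Array[String]): Unit = {
    // decimal(38, 18) / decimal(38, 18) would naively need decimal(95, 57):
    println(adjust(95, 57)) // (38, 6)  -- cf. Divide(u, u) -> DecimalType(38, 6) below
    // decimal(38, 18) + decimal(38, 18) would naively need decimal(39, 18):
    println(adjust(39, 18)) // (38, 17) -- cf. Add(expr, u) -> DecimalType(38, 17) below
  }
}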
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index fb3dbe8ed199..2da87113c622 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -27,7 +27,6 @@ import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.unsafe.types.UTF8String; @@ -55,36 +54,27 @@ private String getRandomString(int length) { } private UnsafeRow makeKeyRow(long k1, String k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 32); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, UTF8String.fromString(k2)); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeKeyRow(long k1, long k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, k2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeValueRow(long v1, long v2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, v1); writer.write(1, v2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) { diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java index 711887f02832..cd8bce623c5d 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -24,6 +24,8 @@ import java.util.Set; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.junit.Assert; import org.junit.Test; @@ -74,9 +76,6 @@ public void testKnownByteArrayInputs() { Assert.assertEquals(0x739840CB819FA723L, XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 1, PRIME)); - // These tests currently fail in a big endian environment because the test data and expected - // answers are generated with little endian the assumptions. We could revisit this when Platform - // becomes endian aware. 
if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) { Assert.assertEquals(0x9256E58AA397AEF1L, hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); @@ -94,6 +93,23 @@ public void testKnownByteArrayInputs() { hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); Assert.assertEquals(0xCAA65939306F1E21L, XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); + } else { + Assert.assertEquals(0x7F875412350ADDDCL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); + Assert.assertEquals(0x564D279F524D8516L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4, PRIME)); + Assert.assertEquals(0x7D9F07E27E0EB006L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8)); + Assert.assertEquals(0x893CEF564CB7858L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8, PRIME)); + Assert.assertEquals(0xC6198C4C9CC49E17L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14)); + Assert.assertEquals(0x4E21BEF7164D4BBL, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14, PRIME)); + Assert.assertEquals(0xBCF5FAEDEE1F2B5AL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); + Assert.assertEquals(0x6F680C877A358FE5L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); } } @@ -128,13 +144,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWordsBlock(mb), + hasher.hashUnsafeWordsBlock(mb)); - hashcodes.add(hasher.hashUnsafeWords( - bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); } // A very loose bound. @@ -151,13 +167,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWordsBlock(mb), + hasher.hashUnsafeWordsBlock(mb)); - hashcodes.add(hasher.hashUnsafeWords( - paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); } // A very loose bound. 
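A small hedged aside (not part of the patch) on why XXH64Suite above now carries separate expected hashes for big-endian platforms: a word-oriented hash reads several bytes as one integer, and that read depends on the platform byte order.

import java.nio.{ByteBuffer, ByteOrder}

object EndiannessSketch {
  def main(args: Array[String]): Unit = {
    val bytes = Array[Byte](1, 2, 3, 4)
    val le = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getInt // 0x04030201
    val be = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN).getInt    // 0x01020304
    // The same input bytes feed different words into the hash, so the expected
    // outputs differ between the two branches of testKnownByteArrayInputs.
    println(f"little-endian: 0x$le%08X, big-endian: 0x$be%08X")
  }
}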
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 2d94b66a1e12..9a89e6290e69 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -40,7 +40,7 @@ object HashBenchmark { safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy() ).toArray - val benchmark = new Benchmark("Hash For " + name, iters * numRows) + val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong) benchmark.addCase("interpreted version") { _: Int => var sum = 0 for (_ <- 0L until iters) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala index 2a753a0c84ed..f6c8111f5bc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -36,7 +36,8 @@ object HashByteArrayBenchmark { bytes } - val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays) + val benchmark = + new Benchmark("Hash byte arrays with length " + length, iters * numArrays.toLong) benchmark.addCase("Murmur3_x86_32") { _: Int => var sum = 0L for (_ <- 0L until iters) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index 769addf3b29e..6c6376994531 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -38,7 +38,7 @@ object UnsafeProjectionBenchmark { val iters = 1024 * 16 val numRows = 1024 * 16 - val benchmark = new Benchmark("unsafe projection", iters * numRows) + val benchmark = new Benchmark("unsafe projection", iters * numRows.toLong) val schema1 = new StructType().add("l", LongType, false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala deleted file mode 100644 index 5b802ccc637d..000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} - -class PartitioningSuite extends SparkFunSuite { - test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { - val expressions = Seq(Literal(2), Literal(3)) - // Consider two HashPartitionings that have the same _set_ of hash expressions but which are - // created with different orderings of those expressions: - val partitioningA = HashPartitioning(expressions, 100) - val partitioningB = HashPartitioning(expressions.reverse, 100) - // These partitionings are not considered equal: - assert(partitioningA != partitioningB) - // However, they both satisfy the same clustered distribution: - val distribution = ClusteredDistribution(expressions) - assert(partitioningA.satisfies(distribution)) - assert(partitioningB.satisfies(distribution)) - // These partitionings compute different hashcodes for the same input row: - def computeHashCode(partitioning: HashPartitioning): Int = { - val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) - hashExprProj.apply(InternalRow.empty).hashCode() - } - assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) - // Thus, these partitionings are incompatible: - assert(!partitioningA.compatibleWith(partitioningB)) - assert(!partitioningB.compatibleWith(partitioningA)) - assert(!partitioningA.guarantees(partitioningB)) - assert(!partitioningB.guarantees(partitioningA)) - - // Just to be sure that we haven't cheated by having these methods always return false, - // check that identical partitionings are still compatible with and guarantee each other: - assert(partitioningA === partitioningA) - assert(partitioningA.guarantees(partitioningA)) - assert(partitioningA.compatibleWith(partitioningA)) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 23e866cdf491..353b8344658f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -356,4 +356,23 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) } + + test("SPARK-23025: schemaFor should support Null type") { + val schema = schemaFor[(Int, Null)] + assert(schema === Schema( + StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", NullType, nullable = true))), + nullable = true)) + } + + test("SPARK-23835: add null check to non-nullable types in Tuples") { + def numberOfCheckedArguments(deserializer: 
Expression): Int = { + assert(deserializer.isInstanceOf[NewInstance]) + deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) + } + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f4514205d3ae..cd8579584ead 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -408,8 +408,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType) assertExpressionType(sum(Divide(1, 2.0f)), DoubleType) assertExpressionType(sum(Divide(1.0f, 2)), DoubleType) - assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11)) - assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11)) + assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(22, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(26, 6)) assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 549a4355dfba..3d7c91870133 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -54,8 +54,7 @@ trait AnalysisTest extends PlanTest { expectedPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { val analyzer = getAnalyzer(caseSensitive) - val actualPlan = analyzer.execute(inputPlan) - analyzer.checkAnalysis(actualPlan) + val actualPlan = analyzer.executeAndCheck(inputPlan) comparePlans(actualPlan, expectedPlan) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 60e46a9910a8..c86dc18dfa68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -136,19 +136,19 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { test("maximum decimals") { for (expr <- Seq(d1, d2, i, u)) { - checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT) - checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT) + checkType(Add(expr, u), DecimalType(38, 17)) + checkType(Subtract(expr, u), DecimalType(38, 17)) } - checkType(Multiply(d1, u), DecimalType(38, 19)) - checkType(Multiply(d2, u), DecimalType(38, 20)) - checkType(Multiply(i, u), DecimalType(38, 18)) - checkType(Multiply(u, u), DecimalType(38, 36)) + checkType(Multiply(d1, u), DecimalType(38, 16)) + checkType(Multiply(d2, u), DecimalType(38, 14)) + checkType(Multiply(i, u), DecimalType(38, 7)) + checkType(Multiply(u, u), DecimalType(38, 6)) - checkType(Divide(u, d1), DecimalType(38, 18)) - checkType(Divide(u, d2), DecimalType(38, 
19)) - checkType(Divide(u, i), DecimalType(38, 23)) - checkType(Divide(u, u), DecimalType(38, 18)) + checkType(Divide(u, d1), DecimalType(38, 17)) + checkType(Divide(u, d2), DecimalType(38, 16)) + checkType(Divide(u, i), DecimalType(38, 18)) + checkType(Divide(u, u), DecimalType(38, 6)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala new file mode 100644 index 000000000000..fe57c199b874 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} + +/** + * Test suite for resolving Uuid expressions. 
+ */ +class ResolvedUuidExpressionsSuite extends AnalysisTest { + + private lazy val a = 'a.int + private lazy val r = LocalRelation(a) + private lazy val uuid1 = Uuid().as('_uuid1) + private lazy val uuid2 = Uuid().as('_uuid2) + private lazy val uuid3 = Uuid().as('_uuid3) + private lazy val uuid1Ref = uuid1.toAttribute + + private val analyzer = getAnalyzer(caseSensitive = true) + + private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = { + plan.flatMap { + case p => + p.expressions.flatMap(_.collect { + case u: Uuid => u + }) + } + } + + test("analyzed plan sets random seed for Uuid expression") { + val plan = r.select(a, uuid1) + val resolvedPlan = analyzer.executeAndCheck(plan) + getUuidExpressions(resolvedPlan).foreach { u => + assert(u.resolved) + assert(u.randomSeed.isDefined) + } + } + + test("Uuid expressions should have different random seeds") { + val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) + val resolvedPlan = analyzer.executeAndCheck(plan) + assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3) + } + + test("Different analyzed plans should have different random seeds in Uuids") { + val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) + val resolvedPlan1 = analyzer.executeAndCheck(plan) + val resolvedPlan2 = analyzer.executeAndCheck(plan) + val uuids1 = getUuidExpressions(resolvedPlan1) + val uuids2 = getUuidExpressions(resolvedPlan2) + assert(uuids1.distinct.length == 3) + assert(uuids2.distinct.length == 3) + assert(uuids1.intersect(uuids2).length == 0) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 3661530cd622..0acd3b490447 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -429,6 +429,24 @@ class TypeCoercionSuite extends AnalysisTest { Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), isSymmetric = false) } + + widenTest( + ArrayType(IntegerType, containsNull = true), + ArrayType(IntegerType, containsNull = false), + Some(ArrayType(IntegerType, containsNull = true))) + + widenTest( + MapType(IntegerType, StringType, valueContainsNull = true), + MapType(IntegerType, StringType, valueContainsNull = false), + Some(MapType(IntegerType, StringType, valueContainsNull = true))) + + widenTest( + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = false), + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = false), nullable = true), + Some(new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true))) } test("wider common type for decimal and array") { @@ -506,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -518,11 +536,11 @@ class TypeCoercionSuite 
extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } @@ -539,6 +557,9 @@ class TypeCoercionSuite extends AnalysisTest { val floatLit = Literal.create(1.0f, FloatType) val timestampLit = Literal.create("2017-04-12", TimestampType) val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis()))) + val strArrayLit = Literal(Array("c")) + val intArrayLit = Literal(Array(1)) ruleTest(rule, Coalesce(Seq(doubleLit, intLit, floatLit)), @@ -572,6 +593,16 @@ class TypeCoercionSuite extends AnalysisTest { Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), Cast(doubleLit, StringType), Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(timestampLit, intLit, stringLit)), + Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), + Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)), + Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)), + Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType))))) } test("CreateArray casts") { @@ -792,7 +823,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) @@ -923,6 +954,60 @@ class TypeCoercionSuite extends AnalysisTest { } } + test("type coercion for Elt") { + val rule = TypeCoercion.EltCoercion(conf) + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), + Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), + Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(null), Literal("abc"))), + Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(1), Literal("234"))), + Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), + Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), + Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), + Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(Decimal(10)))), + Elt(Seq(Literal(1), 
Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf("spark.sql.function.eltOutputAsString" -> "true") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf("spark.sql.function.eltOutputAsString" -> "false") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) + } + } + test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use something more than a literal to avoid triggering the simplification rules. @@ -1153,7 +1238,7 @@ class TypeCoercionSuite extends AnalysisTest { */ test("make sure rules do not fire early") { // InConversion - val inConversion = TypeCoercion.InConversion + val inConversion = TypeCoercion.InConversion(conf) ruleTest(inConversion, In(UnresolvedAttribute("a"), Seq(Literal(1))), In(UnresolvedAttribute("a"), Seq(Literal(1))) @@ -1190,25 +1275,47 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) } test("binary comparison with string promotion") { - ruleTest(PromoteStrings, + val rule = TypeCoercion.PromoteStrings(conf) + ruleTest(rule, GreaterThan(Literal("123"), Literal(1)), GreaterThan(Cast(Literal("123"), IntegerType), Literal(1))) - ruleTest(PromoteStrings, + ruleTest(rule, LessThan(Literal(true), Literal("123")), LessThan(Literal(true), Cast(Literal("123"), BooleanType))) - ruleTest(PromoteStrings, + ruleTest(rule, EqualTo(Literal(Array(1, 2)), Literal("123")), EqualTo(Literal(Array(1, 2)), Literal("123"))) - ruleTest(PromoteStrings, + ruleTest(rule, GreaterThan(Literal("1.5"), Literal(BigDecimal("0.5"))), - GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), DoubleType))) + GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), + DoubleType))) + Seq(true, false).foreach { convertToTS => + withSQLConf( + "spark.sql.typeCoercion.compareDateTimestampInTimestamp" -> convertToTS.toString) { + val date0301 = Literal(java.sql.Date.valueOf("2017-03-01")) + val timestamp0301000000 = Literal(Timestamp.valueOf("2017-03-01 00:00:00")) + val timestamp0301000001 = Literal(Timestamp.valueOf("2017-03-01 00:00:01")) + if (convertToTS) { + // `Date` should be treated as timestamp at 00:00:00 See SPARK-23549 + ruleTest(rule, EqualTo(date0301, timestamp0301000000), + EqualTo(Cast(date0301, TimestampType), timestamp0301000000)) + ruleTest(rule, LessThan(date0301, 
timestamp0301000001), + LessThan(Cast(date0301, TimestampType), timestamp0301000001)) + } else { + ruleTest(rule, LessThan(date0301, timestamp0301000000), + LessThan(Cast(date0301, StringType), Cast(timestamp0301000000, StringType))) + ruleTest(rule, LessThan(date0301, timestamp0301000001), + LessThan(Cast(date0301, StringType), Cast(timestamp0301000001, StringType))) + } + } + } } test("cast WindowFrame boundaries to the type they operate upon") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 60d1351fda26..cb487c889354 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -621,6 +621,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = Seq("monotonically_increasing_id")) + assertSupportedForContinuousProcessing( + "TypedFilter", TypedFilter( + null, + null, + null, + null, + new TestStreamingRelationV2(attribute)), OutputMode.Append()) /* ======================================================================================= @@ -771,6 +778,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { } } + /** Assert that the logical plan is supported for continuous processing mode */ + def assertSupportedForContinuousProcessing( + name: String, + plan: LogicalPlan, + outputMode: OutputMode): Unit = { + test(s"continuous processing - $name: supported") { + UnsupportedOperationChecker.checkForContinuous(plan, outputMode) + } + } + /** * Assert that the logical plan is not supported inside a streaming plan.
* @@ -840,4 +857,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite { def this(attribute: Attribute) = this(Seq(attribute)) override def isStreaming: Boolean = true } + + case class TestStreamingRelationV2(output: Seq[Attribute]) extends LeafNode { + def this(attribute: Attribute) = this(Seq(attribute)) + override def isStreaming: Boolean = true + override def nodeName: String = "StreamingRelationV2" + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala index 1acbe34d9a07..2fcaeca34db3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -36,7 +36,7 @@ class ExternalCatalogEventSuite extends SparkFunSuite { private def testWithCatalog( name: String)( f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { - val catalog = newCatalog + val catalog = new ExternalCatalogWithListener(newCatalog) val recorder = mutable.Buffer.empty[ExternalCatalogEvent] catalog.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 95c87ffa20cb..6abab0073cca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -279,7 +279,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } - test("create temp table") { + test("create temp view") { withBasicCatalog { catalog => val tempTable1 = Range(1, 10, 1, 10) val tempTable2 = Range(1, 20, 2, 10) @@ -288,11 +288,11 @@ abstract class SessionCatalogSuite extends AnalysisTest { assert(catalog.getTempView("tbl1") == Option(tempTable1)) assert(catalog.getTempView("tbl2") == Option(tempTable2)) assert(catalog.getTempView("tbl3").isEmpty) - // Temporary table already exists + // Temporary view already exists intercept[TempTableAlreadyExistsException] { catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) } - // Temporary table already exists but we override it + // Temporary view already exists but we override it catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) assert(catalog.getTempView("tbl1") == Option(tempTable2)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1dd040e4696a..5b25bdf907c3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -853,4 +853,73 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { cast("2", LongType).genCode(ctx) assert(ctx.inlinedMutableStates.length == 0) } + + test("SPARK-22825 Cast array to string") { + val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType) + checkEvaluation(ret1, "[1, 2, 3, 4, 5]") + val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType) + checkEvaluation(ret2, "[ab, cde, f]") + val ret3 = 
cast(Literal.create(Array("ab", null, "c")), StringType) + checkEvaluation(ret3, "[ab,, c]") + val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType) + checkEvaluation(ret4, "[ab, cde, f]") + val ret5 = cast( + Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)), + StringType) + checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]") + val ret6 = cast( + Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)), + StringType) + checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]") + val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType) + checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]") + val ret8 = cast( + Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))), + StringType) + checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") + } + + test("SPARK-22973 Cast map to string") { + val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType) + checkEvaluation(ret1, "[1 -> a, 2 -> b, 3 -> c]") + val ret2 = cast( + Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)), + StringType) + checkEvaluation(ret2, "[1 -> a, 2 ->, 3 -> c]") + val ret3 = cast( + Literal.create(Map( + 1 -> Date.valueOf("2014-12-03"), + 2 -> Date.valueOf("2014-12-04"), + 3 -> Date.valueOf("2014-12-05"))), + StringType) + checkEvaluation(ret3, "[1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05]") + val ret4 = cast( + Literal.create(Map( + 1 -> Timestamp.valueOf("2014-12-03 13:01:00"), + 2 -> Timestamp.valueOf("2014-12-04 15:05:00"))), + StringType) + checkEvaluation(ret4, "[1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00]") + val ret5 = cast( + Literal.create(Map( + 1 -> Array(1, 2, 3), + 2 -> Array(4, 5, 6))), + StringType) + checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") + } + + test("SPARK-22981 Cast struct to string") { + val ret1 = cast(Literal.create((1, "a", 0.1)), StringType) + checkEvaluation(ret1, "[1, a, 0.1]") + val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType) + checkEvaluation(ret2, "[1,, a]") + val ret3 = cast(Literal.create( + (Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType) + checkEvaluation(ret3, "[2014-12-03, 2014-12-03 15:05:00]") + val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType) + checkEvaluation(ret4, "[[1, a], 5, 0.1]") + val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType) + checkEvaluation(ret5, "[[1, 2, 3], a, 0.1]") + val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType) + checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 676ba3956ddc..5b71becee2de 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -405,12 +405,12 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-18016: define mutable states by using an array") { val ctx1 = new CodegenContext for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) { - ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;") + ctx1.addMutableState(CodeGenerator.JAVA_INT, "i", v => s"$v 
= $i;") } assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) // When the number of primitive type mutable states is over the threshold, others are // allocated into an array - assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1) + assert(ctx1.arrayCompactedMutableStates.get(CodeGenerator.JAVA_INT).get.arrayNames.size == 1) assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) val ctx2 = new CodegenContext @@ -436,4 +436,67 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.addImmutableStateIfNotExists("String", mutableState2) assert(ctx.inlinedMutableStates.length == 2) } + + test("SPARK-23628: calculateParamLength should compute properly the param length") { + assert(CodeGenerator.calculateParamLength(Seq.range(0, 100).map(Literal(_))) == 101) + assert(CodeGenerator.calculateParamLength( + Seq.range(0, 100).map(x => Literal(x.toLong))) == 201) + } + + test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") { + + val ref = BoundReference(0, IntegerType, true) + val add1 = Add(ref, ref) + val add2 = Add(add1, add1) + val dummy = SubExprEliminationState( + JavaCode.variable("dummy", BooleanType), + JavaCode.variable("dummy", BooleanType)) + + // raw testing of basic functionality + { + val ctx = new CodegenContext + val e = ref.genCode(ctx) + // before + ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) + assert(ctx.subExprEliminationExprs.contains(ref)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) { + assert(ctx.subExprEliminationExprs.contains(add1)) + assert(!ctx.subExprEliminationExprs.contains(ref)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(ref)) + assert(!ctx.subExprEliminationExprs.contains(add1)) + } + + // emulate an actual codegen workload + { + val ctx = new CodegenContext + // before + ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE + assert(ctx.subExprEliminationExprs.contains(add1)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(ref -> dummy)) { + assert(ctx.subExprEliminationExprs.contains(ref)) + assert(!ctx.subExprEliminationExprs.contains(add1)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(add1)) + assert(!ctx.subExprEliminationExprs.contains(ref)) + } + } + + test("SPARK-23986: freshName can generate duplicated names") { + val ctx = new CodegenContext + val names1 = ctx.freshName("myName1") :: ctx.freshName("myName1") :: + ctx.freshName("myName11") :: Nil + assert(names1.distinct.length == 3) + val names2 = ctx.freshName("a") :: ctx.freshName("a") :: + ctx.freshName("a_1") :: ctx.freshName("a_0") :: Nil + assert(names2.distinct.length == 4) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a2..a2851d071c7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -61,28 +61,58 @@ class CollectionExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) - val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) + val d1 = new Decimal().set(10) + val d2 = new Decimal().set(100) + val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0))) + val a5 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) + checkEvaluation(new SortArray(a4), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) + checkEvaluation(SortArray(a4, Literal(true)), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) + checkEvaluation(SortArray(a4, Literal(false)), Seq(d2, d1)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) - checkEvaluation(new SortArray(a4), Seq(null, null)) + checkEvaluation(new SortArray(a5), Seq(null, null)) val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) + + val typeAA = ArrayType(ArrayType(IntegerType)) + val aa1 = Array[java.lang.Integer](1, 2) + val aa2 = Array[java.lang.Integer](3, null, 4) + val arrayArray = Literal.create(Seq(aa2, aa1), typeAA) + + checkEvaluation(new SortArray(arrayArray), Seq(aa1, aa2)) + + val typeAAS = ArrayType(ArrayType(StructType(StructField("a", IntegerType) :: Nil))) + val aas1 = Array(create_row(1)) + val aas2 = Array(create_row(2)) + val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS) + + checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2)) + + checkEvaluation(ArraySort(a0), Seq(1, 2, 3)) + checkEvaluation(ArraySort(a1), Seq[Integer]()) + checkEvaluation(ArraySort(a2), Seq("a", "b")) + checkEvaluation(ArraySort(a3), Seq("a", "b", null)) + checkEvaluation(ArraySort(a4), Seq(d1, d2)) + checkEvaluation(ArraySort(a5), Seq(null, null)) + checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) + checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2)) + checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2)) } test("Array contains") { @@ -105,4 +135,337 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Slice") { + val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType)) + + checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2)) + checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5)) + 
checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6)) + checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)), + "Unexpected value for length") + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)), + "Unexpected value for start") + checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null) + checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null) + checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)), + null) + + checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b")) + checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null)) + checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4)) + } + + test("ArrayJoin") { + def testArrays( + arrays: Seq[Expression], + nullReplacement: Option[Expression], + expected: Seq[String]): Unit = { + assert(arrays.length == expected.length) + arrays.zip(expected).foreach { case (arr, exp) => + checkEvaluation(ArrayJoin(arr, Literal(","), nullReplacement), exp) + } + } + + val arrays = Seq(Literal.create(Seq[String]("a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a", null, "b"), ArrayType(StringType)), + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(Seq[String]("a", "b", null), ArrayType(StringType)), + Literal.create(Seq[String](null, "a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a"), ArrayType(StringType))) + + val withoutNullReplacement = Seq("a,b", "a,b", "", "a,b", "a,b", "a") + val withNullReplacement = Seq("a,b", "a,NULL,b", "NULL", "a,b,NULL", "NULL,a,b", "a") + testArrays(arrays, None, withoutNullReplacement) + testArrays(arrays, Some(Literal("NULL")), withNullReplacement) + + checkEvaluation(ArrayJoin( + Literal.create(null, ArrayType(StringType)), Literal(","), None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(null, StringType), + None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal(","), + Some(Literal.create(null, StringType))), null) + } + + test("Array Min") { + checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) + checkEvaluation( + ArrayMin(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "") + checkEvaluation(ArrayMin(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMin(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMin(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 0.1234) + } + + test("Array max") { + checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) + checkEvaluation( + ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc") + checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) + } + + test("Reverse") { + // Primitive-type 
elements + val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType)) + val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType)) + val ai5 = Literal.create(Seq(1), ArrayType(IntegerType)) + val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType)) + val ai7 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2)) + checkEvaluation(Reverse(ai1), Seq(3, 1, 2)) + checkEvaluation(Reverse(ai2), Seq(3, null, 1, null)) + checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2)) + checkEvaluation(Reverse(ai4), Seq(null, null, null)) + checkEvaluation(Reverse(ai5), Seq(1)) + checkEvaluation(Reverse(ai6), Seq.empty) + checkEvaluation(Reverse(ai7), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) + val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType)) + val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType)) + val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType)) + val as5 = Literal.create(Seq("a"), ArrayType(StringType)) + val as6 = Literal.create(Seq.empty, ArrayType(StringType)) + val as7 = Literal.create(null, ArrayType(StringType)) + val aa = Literal.create( + Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b")) + checkEvaluation(Reverse(as1), Seq("c", "a", "b")) + checkEvaluation(Reverse(as2), Seq("c", null, "a", null)) + checkEvaluation(Reverse(as3), Seq(null, "d", null, "b")) + checkEvaluation(Reverse(as4), Seq(null, null, null)) + checkEvaluation(Reverse(as5), Seq("a")) + checkEvaluation(Reverse(as6), Seq.empty) + checkEvaluation(Reverse(as7), null) + checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) + } + + test("Array Position") { + val a0 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPosition(a0, Literal(3)), 4L) + checkEvaluation(ArrayPosition(a0, Literal(1)), 1L) + checkEvaluation(ArrayPosition(a0, Literal(0)), 0L) + checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null) + + checkEvaluation(ArrayPosition(a1, Literal("")), 2L) + checkEvaluation(ArrayPosition(a1, Literal("a")), 0L) + checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L) + checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayPosition(a3, Literal("")), null) + checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + } + + test("elementAt") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + intercept[Exception] { + checkEvaluation(ElementAt(a0, Literal(0)), null) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { checkEvaluation(ElementAt(a0, 
Literal(1.1)), null) } + checkEvaluation(ElementAt(a0, Literal(4)), null) + checkEvaluation(ElementAt(a0, Literal(-4)), null) + + checkEvaluation(ElementAt(a0, Literal(1)), 1) + checkEvaluation(ElementAt(a0, Literal(2)), 2) + checkEvaluation(ElementAt(a0, Literal(3)), 3) + checkEvaluation(ElementAt(a0, Literal(-3)), 1) + checkEvaluation(ElementAt(a0, Literal(-2)), 2) + checkEvaluation(ElementAt(a0, Literal(-1)), 3) + + checkEvaluation(ElementAt(a1, Literal(1)), null) + checkEvaluation(ElementAt(a1, Literal(2)), "") + checkEvaluation(ElementAt(a1, Literal(-2)), null) + checkEvaluation(ElementAt(a1, Literal(-1)), "") + + checkEvaluation(ElementAt(a2, Literal(1)), null) + + checkEvaluation(ElementAt(a3, Literal(1)), null) + + + val m0 = + Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(ElementAt(m0, Literal(1.0)), null) + + checkEvaluation(ElementAt(m0, Literal("d")), null) + + checkEvaluation(ElementAt(m1, Literal("a")), null) + + checkEvaluation(ElementAt(m0, Literal("a")), "1") + checkEvaluation(ElementAt(m0, Literal("b")), "2") + checkEvaluation(ElementAt(m0, Literal("c")), null) + + checkEvaluation(ElementAt(m2, Literal("a")), null) + } + + test("Concat") { + // Primitive-type elements + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val ai4 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai1, ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai0)), Seq(1, 2, 3, 1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai2)), Seq(1, 2, 3, 4, null, 5)) + checkEvaluation(Concat(Seq(ai0, ai3, ai2)), Seq(1, 2, 3, null, null, 4, null, 5)) + checkEvaluation(Concat(Seq(ai4)), null) + checkEvaluation(Concat(Seq(ai0, ai4)), null) + checkEvaluation(Concat(Seq(ai4, ai0)), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) + val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType)) + val as3 = Literal.create(Seq(null, null), ArrayType(StringType)) + val as4 = Literal.create(null, ArrayType(StringType)) + + val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType))) + + checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as1, as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as0)), Seq("a", "b", "c", "a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as2)), Seq("a", "b", "c", "d", null, "e")) + checkEvaluation(Concat(Seq(as0, as3, as2)), Seq("a", "b", "c", null, null, "d", null, "e")) + checkEvaluation(Concat(Seq(as4)), null) + checkEvaluation(Concat(Seq(as0, as4)), null) + checkEvaluation(Concat(Seq(as4, as0)), null) + + checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) + } + + test("Flatten") { + // Primitive-type test cases + val intArrayType = 
ArrayType(ArrayType(IntegerType)) + + // Main test cases (primitive type) + val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType) + val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType) + + checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6)) + checkEvaluation(Flatten(aim2), Seq(1, 2, 3)) + + // Test cases with an empty array (primitive type) + val aie1 = Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType) + val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType) + val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType) + val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType) + val aie5 = Literal.create(Seq(Seq.empty), intArrayType) + val aie6 = Literal.create(Seq.empty, intArrayType) + + checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie4), Seq.empty) + checkEvaluation(Flatten(aie5), Seq.empty) + checkEvaluation(Flatten(aie6), Seq.empty) + + // Test cases with null elements (primitive type) + val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType) + val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType) + val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType) + + checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null)) + checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null)) + checkEvaluation(Flatten(ain3), Seq(null, null, null, null)) + + // Test cases with a null array (primitive type) + val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType) + val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType) + val aia3 = Literal.create(Seq(null), intArrayType) + val aia4 = Literal.create(null, intArrayType) + + checkEvaluation(Flatten(aia1), null) + checkEvaluation(Flatten(aia2), null) + checkEvaluation(Flatten(aia3), null) + checkEvaluation(Flatten(aia4), null) + + // Non-primitive-type test cases + val strArrayType = ArrayType(ArrayType(StringType)) + val arrArrayType = ArrayType(ArrayType(ArrayType(StringType))) + + // Main test cases (non-primitive type) + val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType) + val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType) + val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType) + + checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f")) + checkEvaluation(Flatten(asm2), Seq("a", "b")) + checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e"))) + + // Test cases with an empty array (non-primitive type) + val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType) + val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType) + val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType) + val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType) + val ase5 = Literal.create(Seq(Seq.empty), strArrayType) + val ase6 = Literal.create(Seq.empty, strArrayType) + + checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase4), Seq.empty) + checkEvaluation(Flatten(ase5), Seq.empty) + checkEvaluation(Flatten(ase6), Seq.empty) + + // Test cases with null elements (non-primitive type) 
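+    // (As in the primitive-type cases above, null elements inside the nested arrays are kept
+    //  in the flattened result; only a null outer element or a null input array, covered in
+    //  the "null array" cases below, makes Flatten return null.)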
+ val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType) + val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType) + val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType) + + checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null)) + checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null)) + checkEvaluation(Flatten(asn3), Seq(null, null, null, null)) + + // Test cases with a null array (non-primitive type) + val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType) + val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType) + val asa3 = Literal.create(Seq(null), strArrayType) + val asa4 = Literal.create(null, strArrayType) + + checkEvaluation(Flatten(asa1), null) + checkEvaluation(Flatten(asa2), null) + checkEvaluation(Flatten(asa3), null) + checkEvaluation(Flatten(asa4), null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 84190f0bd5f7..b4138ce366b3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -180,7 +180,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { null, null) } intercept[RuntimeException] { - checkEvalutionWithUnsafeProjection( + checkEvaluationWithUnsafeProjection( CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), null, null) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 63f6ceeb21b9..63b24fb9eb13 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -22,6 +22,7 @@ import java.text.SimpleDateFormat import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT @@ -210,6 +211,17 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(DayOfWeek, DateType) } + test("WeekDay") { + checkEvaluation(WeekDay(Literal.create(null, DateType)), null) + checkEvaluation(WeekDay(Literal(d)), 2) + checkEvaluation(WeekDay(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 2) + checkEvaluation(WeekDay(Cast(Literal(ts), DateType, gmtId)), 4) + checkEvaluation(WeekDay(Cast(Literal("2011-05-06"), DateType, gmtId)), 4) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("2017-05-27 13:10:15").getTime))), 5) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 4) + checkConsistencyBetweenInterpretedAndCodegen(WeekDay, DateType) + } + test("WeekOfYear") { checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) checkEvaluation(WeekOfYear(Literal(d)), 15) @@ -452,34 +464,47 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { MonthsBetween( Literal(new Timestamp(sdf.parse("1997-02-28 
10:30:00").getTime)), Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), - timeZoneId), - 3.94959677) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), - timeZoneId), - 0.0) + Literal.TrueLiteral, + timeZoneId = timeZoneId), 3.94959677) checkEvaluation( MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - timeZoneId), - -2.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), - timeZoneId), - 1.0) + Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), + Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), + Literal.FalseLiteral, + timeZoneId = timeZoneId), 3.9495967741935485) + + Seq(Literal.FalseLiteral, Literal.TrueLiteral). foreach { roundOff => + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 0.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), -2.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 1.0) + } val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) val tnull = Literal.create(null, TimestampType) - checkEvaluation(MonthsBetween(t, tnull, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, t, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, tnull, timeZoneId), null) + checkEvaluation(MonthsBetween(t, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation(MonthsBetween(tnull, t, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(tnull, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(t, t, Literal.create(null, BooleanType), timeZoneId = timeZoneId), null) checkConsistencyBetweenInterpretedAndCodegen( - (time1: Expression, time2: Expression) => MonthsBetween(time1, time2, timeZoneId), - TimestampType, TimestampType) + (time1: Expression, time2: Expression, roundOff: Expression) => + MonthsBetween(time1, time2, roundOff, timeZoneId = timeZoneId), + TimestampType, TimestampType, BooleanType) } } @@ -791,6 +816,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test(null, "UTC", null) test("2015-07-24 00:00:00", null, null) test(null, null, null) + // Test escaping of timezone + GenerateUnsafeProjection.generate( + ToUTCTimestamp(Literal(Timestamp.valueOf("2015-07-24 00:00:00")), Literal("\"quote")) :: Nil) } test("from_utc_timestamp") { @@ -811,5 +839,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test(null, "UTC", null) test("2015-07-24 00:00:00", null, null) test(null, null, null) + // Test escaping of timezone + GenerateUnsafeProjection.generate(FromUTCTimestamp(Literal(0), Literal("\"quote")) :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4c8eab19c5c..a22e9d4655e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.reflect.ClassTag + import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.exceptions.TestFailedException @@ -24,6 +26,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -44,25 +47,32 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } - protected def checkEvaluation( - expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + private def prepareEvaluation(expression: Expression): Expression = { val serializer = new JavaSerializer(new SparkConf()).newInstance val resolver = ResolveTimeZone(new SQLConf) - val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + } + + protected def checkEvaluation( + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expr.dataType)) { - checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow) + checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow) } checkEvaluationWithOptimization(expr, catalystValue, inputRow) } /** * Check the equality between result of expression and expected value, it will handle - * Array[Byte], Spread[Double], and MapData. + * Array[Byte], Spread[Double], MapData and Row. 
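+   * For a value whose expression type is a UserDefinedType, the comparison is performed on
+   * the underlying sqlType representation (exprDataType is first mapped through
+   * UserDefinedType.sqlType below).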
*/ - protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { + protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = { + val dataType = UserDefinedType.sqlType(exprDataType) + (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) @@ -88,12 +98,45 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected } } - protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + expectedErrMsg: String): Unit = { + checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg) + } + + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + inputRow: InternalRow, + expectedErrMsg: String): Unit = { + + def checkException(eval: => Unit, testMode: String): Unit = { + withClue(s"($testMode)") { + val errMsg = intercept[T] { + eval + }.getMessage + if (!errMsg.contains(expectedErrMsg)) { + fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") + } + } + } + + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) + checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") + checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkException(evaluateWithUnsafeProjection(expr, inputRow), "unsafe mode") + } + } + + protected def evaluateWithoutCodegen( + expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { case n: Nondeterministic => n.initialize(0) case _ => @@ -122,7 +165,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { + val actual = try evaluateWithoutCodegen(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } if (!checkResult(actual, expected, expression.dataType)) { @@ -137,33 +180,38 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val actual = evaluateWithGeneratedMutableProjection(expression, inputRow) + if (!checkResult(actual, expected, expression.dataType)) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + protected def evaluateWithGeneratedMutableProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): Any = { val plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) plan.initialize(0) - val actual = plan(inputRow).get(0, expression.dataType) - if (!checkResult(actual, expected, expression.dataType)) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") - } + plan(inputRow).get(0, 
expression.dataType) } - protected def checkEvalutionWithUnsafeProjection( + protected def checkEvaluationWithUnsafeProjection( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - // SPARK-16489 Explicitly doing code generation twice so code gen will fail if - // some expression is reusing variable names across different instances. - // This behavior is tested in ExpressionEvalHelperSuite. - val plan = generateProject( - UnsafeProjection.create( - Alias(expression, s"Optimized($expression)1")() :: - Alias(expression, s"Optimized($expression)2")() :: Nil), - expression) + checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection) + checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection) + } - val unsafeRow = plan(inputRow) + protected def checkEvaluationWithUnsafeProjection( + expression: Expression, + expected: Any, + inputRow: InternalRow, + factory: UnsafeProjectionCreator): Unit = { + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory) val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" if (expected == null) { @@ -175,7 +223,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } else { val lit = InternalRow(expected, expected) val expectedRow = - UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + factory.create(Array(expression.dataType, expression.dataType)).apply(lit) if (unsafeRow != expectedRow) { fail("Incorrect evaluation in unsafe mode: " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") @@ -183,6 +231,23 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } + protected def evaluateWithUnsafeProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow, + factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { + // SPARK-16489 Explicitly doing code generation twice so code gen will fail if + // some expression is reusing variable names across different instances. + // This behavior is tested in ExpressionEvalHelperSuite. 
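+    // Because both aliases evaluate the same expression, the projected row is expected to
+    // hold the value twice; checkEvaluationWithUnsafeProjection compares it against an
+    // expected row built from InternalRow(expected, expected).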
+ val plan = generateProject( + UnsafeProjection.create( + Alias(expression, s"Optimized($expression)1")() :: + Alias(expression, s"Optimized($expression)2")() :: Nil), + expression) + + plan.initialize(0) + plan(inputRow) + } + protected def checkEvaluationWithOptimization( expression: Expression, expected: Any, @@ -292,7 +357,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { val interpret = try { - evaluate(expr, inputRow) + evaluateWithoutCodegen(expr, inputRow) } catch { case e: Exception => fail(s"Exception evaluating $expr", e) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index a0bbe02f9235..00e97637eee7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -22,11 +22,13 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { +class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase { val json = """ |{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}], @@ -390,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), InternalRow(1) ) } @@ -399,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), null ) // Other modes should still return `null`. 
checkEvaluation( - JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true), null ) } @@ -414,62 +416,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true), null ) } @@ -477,7 +479,7 @@ class 
JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-20549: from_json bad UTF-8") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true), null) } @@ -489,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true), InternalRow(c.getTimeInMillis * 1000L) ) @@ -510,7 +512,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), - Option(tz.getID)), + Option(tz.getID), + true), InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( @@ -519,7 +522,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), - gmtId), + gmtId, + true), InternalRow(c.getTimeInMillis * 1000L) ) } @@ -528,7 +532,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true), null ) } @@ -680,4 +684,26 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } } + + test("from_json missing fields") { + for (forceJsonNullableSchema <- Seq(false, true)) { + val input = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = JsonToStructs( + jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 39e0060d41dd..3a094079380f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -124,7 +124,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { private def checkNaNWithoutCodegen( expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { - val actual = try 
evaluate(expression, inputRow) catch { + val actual = try evaluateWithoutCodegen(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } if (!actual.asInstanceOf[Double].isNaN) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index facc86308130..b6c269348b00 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import java.io.PrintStream + +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -40,7 +44,49 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("uuid") { - checkEvaluation(Length(Uuid()), 36) - assert(evaluate(Uuid()) !== evaluate(Uuid())) + checkEvaluation(Length(Uuid(Some(0))), 36) + val r = new Random() + val seed1 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Uuid(seed1)) === evaluateWithoutCodegen(Uuid(seed1))) + assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) === + evaluateWithGeneratedMutableProjection(Uuid(seed1))) + assert(evaluateWithUnsafeProjection(Uuid(seed1)) === + evaluateWithUnsafeProjection(Uuid(seed1))) + + val seed2 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Uuid(seed1)) !== evaluateWithoutCodegen(Uuid(seed2))) + assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) !== + evaluateWithGeneratedMutableProjection(Uuid(seed2))) + assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== + evaluateWithUnsafeProjection(Uuid(seed2))) + + val uuid = Uuid(seed1) + assert(uuid.fastEquals(uuid)) + assert(!uuid.fastEquals(Uuid(seed1))) + assert(!uuid.fastEquals(uuid.freshCopy())) + assert(!uuid.fastEquals(Uuid(seed2))) + } + + test("PrintToStderr") { + val inputExpr = Literal(1) + val systemErr = System.err + + val (outputEval, outputCodegen) = try { + val errorStream = new java.io.ByteArrayOutputStream() + System.setErr(new PrintStream(errorStream)) + // check without codegen + checkEvaluationWithoutCodegen(PrintToStderr(inputExpr), 1) + val outputEval = errorStream.toString + errorStream.reset() + // check with codegen + checkEvaluationWithGeneratedMutableProjection(PrintToStderr(inputExpr), 1) + val outputCodegen = errorStream.toString + (outputEval, outputCodegen) + } finally { + System.setErr(systemErr) + } + + assert(outputCodegen.contains(s"Result of $inputExpr is 1")) + assert(outputEval.contains(s"Result of $inputExpr is 1")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index cc6c15cb2c90..424c3a469607 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -51,7 +51,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("AssertNotNUll") { val ex = intercept[RuntimeException] { - evaluate(AssertNotNull(Literal(null), Seq.empty[String])) + evaluateWithoutCodegen(AssertNotNull(Literal(null), Seq.empty[String])) }.getMessage assert(ex.contains("Null value appeared in 
non-nullable field")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 3edcc02f1526..77ca640f2e0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,13 +17,52 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} -import org.apache.spark.sql.types.{IntegerType, ObjectType} +import java.sql.{Date, Timestamp} +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag +import scala.util.Random + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +class InvokeTargetClass extends Serializable { + def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 + def filterPrimitiveInt(e: Int): Boolean = e > 0 + def binOp(e1: Int, e2: Double): Double = e1 + e2 +} + +class InvokeTargetSubClass extends InvokeTargetClass { + override def binOp(e1: Int, e2: Double): Double = e1 - e2 +} + +// Tests for NewInstance +class Outer extends Serializable { + class Inner(val value: Int) { + override def hashCode(): Int = super.hashCode() + override def equals(other: Any): Boolean = { + if (other.isInstanceOf[Inner]) { + value == other.asInstanceOf[Inner].value + } else { + false + } + } + } +} class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -41,16 +80,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4)))) val structExpected = new GenericArrayData( Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) - checkEvalutionWithUnsafeProjection( - structEncoder.serializer.head, structExpected, structInputRow) + checkEvaluationWithUnsafeProjection( + structEncoder.serializer.head, + structExpected, + structInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed // test UnsafeArray-backed data val arrayEncoder = ExpressionEncoder[Array[Array[Int]]] val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 
2), Array(3, 4)))) val arrayExpected = new GenericArrayData( Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) - checkEvalutionWithUnsafeProjection( - arrayEncoder.serializer.head, arrayExpected, arrayInputRow) + checkEvaluationWithUnsafeProjection( + arrayEncoder.serializer.head, + arrayExpected, + arrayInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed // test UnsafeMap-backed data val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]] @@ -63,7 +108,511 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new ArrayBasedMapData( new GenericArrayData(Array(3, 4)), new GenericArrayData(Array(300, 400))))) - checkEvalutionWithUnsafeProjection( - mapEncoder.serializer.head, mapExpected, mapInputRow) + checkEvaluationWithUnsafeProjection( + mapEncoder.serializer.head, + mapExpected, + mapInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + } + + test("SPARK-23582: StaticInvoke should support interpreted execution") { + Seq((classOf[java.lang.Boolean], "true", true), + (classOf[java.lang.Byte], "1", 1.toByte), + (classOf[java.lang.Short], "257", 257.toShort), + (classOf[java.lang.Integer], "12345", 12345), + (classOf[java.lang.Long], "12345678", 12345678.toLong), + (classOf[java.lang.Float], "12.34", 12.34.toFloat), + (classOf[java.lang.Double], "1.2345678", 1.2345678) + ).foreach { case (cls, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, ObjectType(cls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[java.lang.String]), true))), + expected, InternalRow.fromSeq(Seq(arg))) + } + + // Return null when null argument is passed with propagateNull = true + val stringCls = classOf[java.lang.String] + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = true), + null, InternalRow.fromSeq(Seq(null))) + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = false), + "null", InternalRow.fromSeq(Seq(null))) + + // test no argument + val clCls = classOf[java.lang.ClassLoader] + checkObjectExprEvaluation(StaticInvoke(clCls, ObjectType(clCls), "getSystemClassLoader", Nil), + ClassLoader.getSystemClassLoader, InternalRow.empty) + // test more than one argument + val intCls = classOf[java.lang.Integer] + checkObjectExprEvaluation(StaticInvoke(intCls, ObjectType(intCls), "compare", + Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, false))), + 0, InternalRow.fromSeq(Seq(7, 7))) + + Seq((DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", ObjectType(classOf[Timestamp]), + new Timestamp(77777), DateTimeUtils.fromJavaTimestamp(new Timestamp(77777))), + (DateTimeUtils.getClass, DateType, "fromJavaDate", ObjectType(classOf[Date]), + new Date(88888888), DateTimeUtils.fromJavaDate(new Date(88888888))), + (classOf[UTF8String], StringType, "fromString", ObjectType(classOf[String]), + "abc", UTF8String.fromString("abc")), + (Decimal.getClass, DecimalType(38, 0), "fromDecimal", ObjectType(classOf[Any]), + BigInt(88888888), Decimal.fromDecimal(BigInt(88888888))), + (Decimal.getClass, DecimalType.SYSTEM_DEFAULT, + "apply", ObjectType(classOf[java.math.BigInteger]), + new java.math.BigInteger("88888888"), Decimal.apply(new java.math.BigInteger("88888888"))), + (classOf[ArrayData], ArrayType(IntegerType), "toArrayData", 
ObjectType(classOf[Any]), + Array[Int](1, 2, 3), ArrayData.toArrayData(Array[Int](1, 2, 3))), + (classOf[UnsafeArrayData], ArrayType(IntegerType, false), + "fromPrimitiveArray", ObjectType(classOf[Array[Int]]), + Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))), + (DateTimeUtils.getClass, ObjectType(classOf[Date]), + "toJavaDate", ObjectType(classOf[DateTimeUtils.SQLDate]), 77777, + DateTimeUtils.toJavaDate(77777)), + (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]), + "toJavaTimestamp", ObjectType(classOf[DateTimeUtils.SQLTimestamp]), + 88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888)) + ).foreach { case (cls, dataType, methodName, argType, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName, + Seq(BoundReference(0, argType, true))), expected, InternalRow.fromSeq(Seq(arg))) + } + } + + test("SPARK-23583: Invoke should support interpreted execution") { + val targetObject = new InvokeTargetClass + val funcClass = classOf[InvokeTargetClass] + val funcObj = Literal.create(targetObject, ObjectType(funcClass)) + val targetSubObject = new InvokeTargetSubClass + val funcSubObj = Literal.create(targetSubObject, ObjectType(classOf[InvokeTargetSubClass])) + val funcNullObj = Literal.create(null, ObjectType(funcClass)) + + val inputInt = Seq(BoundReference(0, ObjectType(classOf[Any]), true)) + val inputPrimitiveInt = Seq(BoundReference(0, IntegerType, false)) + val inputSum = Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false)) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt), + java.lang.Boolean.valueOf(true), InternalRow.fromSeq(Seq(Integer.valueOf(1)))) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterPrimitiveInt", BooleanType, inputPrimitiveInt), + false, InternalRow.fromSeq(Seq(-1))) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt), + null, InternalRow.fromSeq(Seq(null))) + + checkObjectExprEvaluation( + Invoke(funcNullObj, "filterInt", ObjectType(classOf[Any]), inputInt), + null, InternalRow.fromSeq(Seq(Integer.valueOf(1)))) + + checkObjectExprEvaluation( + Invoke(funcObj, "binOp", DoubleType, inputSum), 1.25, InternalRow.apply(1, 0.25)) + + checkObjectExprEvaluation( + Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) + } + + test("SPARK-23593: InitializeJavaBean should support interpreted execution") { + val list = new java.util.LinkedList[Int]() + list.add(1) + + val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), + Map("add" -> Literal(1))) + checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) + + val initializeWithNonexistingMethod = InitializeJavaBean( + Literal.fromObject(new java.util.LinkedList[Int]), + Map("nonexisting" -> Literal(1))) + checkExceptionInExpression[Exception](initializeWithNonexistingMethod, + """A method named "nonexisting" is not declared in any enclosing class """ + + "nor any supertype") + + val initializeWithWrongParamType = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setX" -> Literal("1"))) + intercept[Exception] { + evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) + }.getMessage.contains( + """A method named "setX" is not declared in any enclosing class """ + + "nor any supertype") + } + + test("InitializeJavaBean doesn't call setters if input in null") { + val initializeBean = InitializeJavaBean( + 
Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal(null))) + evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) + + val initializeBean2 = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal("string"))) + evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) + } + + test("SPARK-23585: UnwrapOption should support interpreted execution") { + val cls = classOf[Option[Int]] + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val unwrapObject = UnwrapOption(IntegerType, inputObject) + Seq((Some(1), 1), (None, null), (null, null)).foreach { case (input, expected) => + checkEvaluation(unwrapObject, expected, InternalRow.fromSeq(Seq(input))) + } + } + + test("SPARK-23586: WrapOption should support interpreted execution") { + val cls = ObjectType(classOf[java.lang.Integer]) + val inputObject = BoundReference(0, cls, nullable = true) + val wrapObject = WrapOption(inputObject, cls) + Seq((1, Some(1)), (null, None)).foreach { case (input, expected) => + checkEvaluation(wrapObject, expected, InternalRow.fromSeq(Seq(input))) + } + } + + test("SPARK-23590: CreateExternalRow should support interpreted execution") { + val schema = new StructType().add("a", IntegerType).add("b", StringType) + val createExternalRow = CreateExternalRow(Seq(Literal(1), Literal("x")), schema) + checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) + } + + // by scala values instead of catalyst values. + private def checkObjectExprEvaluation( + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val serializer = new JavaSerializer(new SparkConf()).newInstance + val resolver = ResolveTimeZone(new SQLConf) + val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + checkEvaluationWithoutCodegen(expr, expected, inputRow) + checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow) + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkEvaluationWithUnsafeProjection( + expr, + expected, + inputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + } + checkEvaluationWithOptimization(expr, expected, inputRow) + } + + test("SPARK-23594 GetExternalRowField should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") + Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => + checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + } + + // If an input row or a field are null, a runtime exception will be thrown + checkExceptionInExpression[RuntimeException]( + getRowField, + InternalRow.fromSeq(Seq(null)), + "The input external row cannot be null.") + checkExceptionInExpression[RuntimeException]( + getRowField, + InternalRow.fromSeq(Seq(Row(null))), + "The 0th field 'c0' of input row cannot be null.") + } + + test("SPARK-23591: EncodeUsingSerializer should support interpreted execution") { + val cls = ObjectType(classOf[java.lang.Integer]) + val inputObject = BoundReference(0, cls, nullable = true) + val conf = new SparkConf() + Seq(true, false).foreach { useKryo => + val serializer = if (useKryo) new KryoSerializer(conf) 
else new JavaSerializer(conf) + val expected = serializer.newInstance().serialize(new Integer(1)).array() + val encodeUsingSerializer = EncodeUsingSerializer(inputObject, useKryo) + checkEvaluation(encodeUsingSerializer, expected, InternalRow.fromSeq(Seq(1))) + checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) + } + } + + test("SPARK-23587: MapObjects should support interpreted execution") { + def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = { + val function = (lambda: Expression) => Add(lambda, Literal(1)) + val elementType = IntegerType + val expected = Seq(2, 3, 4) + + val inputObject = BoundReference(0, inputType, nullable = true) + val optClass = Option(collectionCls) + val mapObj = MapObjects(function, inputObject, elementType, true, optClass) + val row = InternalRow.fromSeq(Seq(collection)) + val result = mapObj.eval(row) + + collectionCls match { + case null => + assert(result.asInstanceOf[ArrayData].array.toSeq == expected) + case l if classOf[java.util.List[_]].isAssignableFrom(l) => + assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected) + case s if classOf[Seq[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[Seq[_]].toSeq == expected) + case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet) + } + } + + val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]], + classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]], + classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]], + classOf[java.util.Stack[Int]], null) + + val list = new java.util.ArrayList[Int]() + list.add(1) + list.add(2) + list.add(3) + val arrayData = new GenericArrayData(Array(1, 2, 3)) + val vector = new java.util.Vector[Int]() + vector.add(1) + vector.add(2) + vector.add(3) + val stack = new java.util.Stack[Int]() + stack.add(1) + stack.add(2) + stack.add(3) + + Seq( + (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])), + (Array(1, 2, 3), ObjectType(classOf[Array[Int]])), + (Seq(1, 2, 3), ObjectType(classOf[Object])), + (Array(1, 2, 3), ObjectType(classOf[Object])), + (list, ObjectType(classOf[java.util.List[Int]])), + (vector, ObjectType(classOf[java.util.Vector[Int]])), + (stack, ObjectType(classOf[java.util.Stack[Int]])), + (arrayData, ArrayType(IntegerType)) + ).foreach { case (collection, inputType) => + customCollectionClasses.foreach(testMapObjects(collection, _, inputType)) + + // Unsupported custom collection class + val errMsg = intercept[RuntimeException] { + testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType) + }.getMessage() + assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " + + "as resulting collection.")) + } + } + + test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") { + val cls = classOf[java.lang.Integer] + val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true) + val conf = new SparkConf() + Seq(true, false).foreach { useKryo => + val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) + val input = serializer.newInstance().serialize(new Integer(1)).array() + val decodeUsingSerializer = DecodeUsingSerializer(inputObject, ClassTag(cls), useKryo) + checkEvaluation(decodeUsingSerializer, new Integer(1), InternalRow.fromSeq(Seq(input))) + checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) + 
} + } + + test("SPARK-23584 NewInstance should support interpreted execution") { + // Normal case test + val newInst1 = NewInstance( + cls = classOf[GenericArrayData], + arguments = Literal.fromObject(List(1, 2, 3)) :: Nil, + propagateNull = false, + dataType = ArrayType(IntegerType), + outerPointer = None) + checkObjectExprEvaluation(newInst1, new GenericArrayData(List(1, 2, 3))) + + // Inner class case test + val outerObj = new Outer() + val newInst2 = NewInstance( + cls = classOf[outerObj.Inner], + arguments = Literal(1) :: Nil, + propagateNull = false, + dataType = ObjectType(classOf[outerObj.Inner]), + outerPointer = Some(() => outerObj)) + checkObjectExprEvaluation(newInst2, new outerObj.Inner(1)) } + + test("LambdaVariable should support interpreted execution") { + def genSchema(dt: DataType): Seq[StructType] = { + Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil), + StructType(StructField("col_1", dt, nullable = true) :: Nil)) + } + + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val mapTypes = elementTypes.flatMap { elementType => + Seq(MapType(elementType, elementType, false), MapType(elementType, elementType, true)) + } + val structTypes = elementTypes.flatMap { elementType => + Seq(StructType(StructField("col1", elementType, false) :: Nil), + StructType(StructField("col1", elementType, true) :: Nil)) + } + + val testTypes = elementTypes ++ arrayTypes ++ mapTypes ++ structTypes + val random = new Random(100) + testTypes.foreach { dt => + genSchema(dt).map { schema => + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + val lambda = LambdaVariable("dummy", "dummuIsNull", schema(0).dataType, schema(0).nullable) + checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow) + } + } + } + + implicit private def mapIntStrEncoder = ExpressionEncoder[Map[Int, String]]() + + test("SPARK-23588 CatalystToExternalMap should support interpreted execution") { + // To get a resolved `CatalystToExternalMap` expression, we build a deserializer plan + // with dummy input, resolve the plan by the analyzer, and replace the dummy input + // with a literal for tests. 
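+    // The deserializer obtained from the encoder is unresolved until the analyzer binds it to
+    // a concrete child plan, which is presumably why a resolved CatalystToExternalMap cannot
+    // simply be constructed by hand here.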
+ val unresolvedDeser = UnresolvedDeserializer(encoderFor[Map[Int, String]].deserializer) + val dummyInputPlan = LocalRelation('value.map(MapType(IntegerType, StringType))) + val plan = Project(Alias(unresolvedDeser, "none")() :: Nil, dummyInputPlan) + + val analyzedPlan = SimpleAnalyzer.execute(plan) + val Alias(toMapExpr: CatalystToExternalMap, _) = analyzedPlan.expressions.head + + // Replaces the dummy input with a literal for tests here + val data = Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val deserializer = toMapExpr.copy(inputData = Literal.create(data)) + checkObjectExprEvaluation(deserializer, expected = data) + } + + test("SPARK-23595 ValidateExternalType should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + Seq( + (true, BooleanType), + (2.toByte, ByteType), + (5.toShort, ShortType), + (23, IntegerType), + (61L, LongType), + (1.0f, FloatType), + (10.0, DoubleType), + ("abcd".getBytes, BinaryType), + ("abcd", StringType), + (BigDecimal.valueOf(10), DecimalType.IntDecimal), + (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType), + (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal), + (Array(3, 2, 1), ArrayType(IntegerType)) + ).foreach { case (input, dt) => + val validateType = ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt) + checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input)))) + } + + checkExceptionInExpression[RuntimeException]( + ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType), + InternalRow.fromSeq(Seq(Row(1))), + "java.lang.Integer is not a valid external type for schema of double") + } + + private def javaMapSerializerFor( + keyClazz: Class[_], + valueClazz: Class[_])(inputObject: Expression): Expression = { + + def kvSerializerFor(inputObject: Expression, clazz: Class[_]): Expression = clazz match { + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + } + + ExternalMapToCatalyst( + inputObject, + ObjectType(keyClazz), + kvSerializerFor(_, keyClazz), + keyNullable = true, + ObjectType(valueClazz), + kvSerializerFor(_, valueClazz), + valueNullable = true + ) + } + + private def scalaMapSerializerFor[T: TypeTag, U: TypeTag](inputObject: Expression): Expression = { + import org.apache.spark.sql.catalyst.ScalaReflection._ + + val curId = new java.util.concurrent.atomic.AtomicInteger() + + def kvSerializerFor[V: TypeTag](inputObject: Expression): Expression = + localTypeOf[V].dealias match { + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + case _ => + inputObject + } + + ExternalMapToCatalyst( + inputObject, + dataTypeFor[T], + kvSerializerFor[T], + keyNullable = !localTypeOf[T].typeSymbol.asClass.isPrimitive, + dataTypeFor[U], + kvSerializerFor[U], + valueNullable = !localTypeOf[U].typeSymbol.asClass.isPrimitive + ) + } + + test("SPARK-23589 ExternalMapToCatalyst should support interpreted execution") { + // Simple test + val scalaMap = scala.collection.Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val javaMap = 
new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(0, "v0") + put(1, "v1") + put(2, null) + put(3, "v3") + } + } + val expected = CatalystTypeConverters.convertToCatalyst(scalaMap) + + // Java Map + val serializer1 = javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMap)) + checkEvaluation(serializer1, expected) + + // Scala Map + val serializer2 = scalaMapSerializerFor[Int, String](Literal.fromObject(scalaMap)) + checkEvaluation(serializer2, expected) + + // NULL key test + val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String]( + null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1") + val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(null, "v0") + put(1, "v1") + } + } + + // Java Map + val serializer3 = + javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMapHasNullKey)) + checkExceptionInExpression[RuntimeException]( + serializer3, EmptyRow, "Cannot use null as map key!") + + // Scala Map + val serializer4 = scalaMapSerializerFor[java.lang.Integer, String]( + Literal.fromObject(scalaMapHasNullKey)) + + checkExceptionInExpression[RuntimeException]( + serializer4, EmptyRow, "Cannot use null as map key!") + } +} + +class TestBean extends Serializable { + private var x: Int = 0 + + def setX(i: Int): Unit = x = i + + def setNonPrimitive(i: AnyRef): Unit = + assert(i != null, "this setter should not be called with null.") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 8a8f8e10225f..ac76b17ef476 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -442,4 +442,17 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } + + test("SPARK-24007: EqualNullSafe for FloatType and DoubleType might generate a wrong result") { + checkEvaluation(EqualNullSafe(Literal(null, FloatType), Literal(-1.0f)), false) + checkEvaluation(EqualNullSafe(Literal(-1.0f), Literal(null, FloatType)), false) + checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false) + checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false) + } + + test("Interpreted Predicate should initialize nondeterministic expressions") { + val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0))) + interpreted.initialize(0) + assert(interpreted.eval(new UnsafeRow())) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 2a0a42c65b08..d532dc4f7719 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -100,12 +100,12 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // invalid escaping val invalidEscape = intercept[AnalysisException] { - evaluate("""a""" like """\a""") + evaluateWithoutCodegen("""a""" 
like """\a""") } assert(invalidEscape.getMessage.contains("pattern")) val endEscape = intercept[AnalysisException] { - evaluate("""a""" like """a\""") + evaluateWithoutCodegen("""a""" like """a\""") } assert(endEscape.getMessage.contains("pattern")) @@ -147,11 +147,11 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkLiteralRow("abc" rlike _, "^bc", false) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike "**") + evaluateWithoutCodegen("abbbbc" rlike "**") } intercept[java.util.regex.PatternSyntaxException] { val regex = 'a.string.at(0) - evaluate("abbbbc" rlike regex, create_row("**")) + evaluateWithoutCodegen("abbbbc" rlike regex, create_row("**")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 10e3ffd0dff9..e083ae008924 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -43,7 +43,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { assert(e1.getMessage.contains("Failed to execute user defined function")) val e2 = intercept[SparkException] { - checkEvalutionWithUnsafeProjection(udf, null) + checkEvaluationWithUnsafeProjection(udf, null) } assert(e2.getMessage.contains("Failed to execute user defined function")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 54cde77176e2..f1a6f9b8889f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -51,6 +51,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow) } + test("SPARK-22771 Check Concat.checkInputDataTypes results") { + assert(Concat(Seq.empty[Expression]).checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a") :: Literal.create("b") :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a".getBytes) :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create(1) :: Literal.create(2) :: Nil) + .checkInputDataTypes().isFailure) + assert(Concat(Literal.create("a") :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isFailure) + } + test("concat_ws") { def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { val inputExprs = inputs.map { @@ -617,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("REVERSE") { val s = 'a.string.at(0) val row1 = create_row("abccc") - checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) - checkEvaluation(StringReverse(s), "cccba", row1) - checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1) + checkEvaluation(Reverse(Literal("abccc")), "cccba", row1) + checkEvaluation(Reverse(s), "cccba", row1) + checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1) } test("SPACE") { @@ -744,7 +756,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // 
exceptional cases intercept[java.util.regex.PatternSyntaxException] { - evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), + evaluateWithoutCodegen(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), Literal("QUERY"), Literal("???")))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index d6c8fcf29184..351d4d0c2eac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -27,7 +27,7 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva test("time window is unevaluable") { intercept[UnsupportedOperationException] { - evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) + evaluateWithoutCodegen(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index cf3cbe270753..c07da122cd7b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String @@ -33,10 +33,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - test("basic conversion with only primitive types") { - val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) - val converter = UnsafeProjection.create(fieldTypes) + private def testWithFactory( + name: String)( + f: UnsafeProjectionCreator => Unit): Unit = { + test(name) { + f(UnsafeProjection) + f(InterpretedUnsafeProjection) + } + } + testWithFactory("basic conversion with only primitive types") { factory => + val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) row.setLong(1, 1) @@ -71,9 +79,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow2.getInt(2) === 2) } - test("basic conversion with primitive, string and binary types") { + testWithFactory("basic conversion with primitive, string and binary types") { factory => val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) @@ -90,9 +98,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } - test("basic conversion with primitive, string, date and timestamp types") { + testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory 
=> val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) @@ -119,7 +127,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { (Timestamp.valueOf("2015-06-22 08:10:25")) } - test("null handling") { + testWithFactory("null handling") { factory => val fieldTypes: Array[DataType] = Array( NullType, BooleanType, @@ -135,7 +143,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DecimalType.SYSTEM_DEFAULT // ArrayType(IntegerType) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val rowWithAllNullColumns: InternalRow = { val r = new SpecificInternalRow(fieldTypes) @@ -240,7 +248,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - test("NaN canonicalization") { + testWithFactory("NaN canonicalization") { factory => val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificInternalRow(fieldTypes) @@ -251,17 +259,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - test("basic conversion with struct type") { + testWithFactory("basic conversion with struct type") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("i", IntegerType), new StructType().add("nest", new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(1)) @@ -317,12 +325,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } - test("basic conversion with array type") { + testWithFactory("basic conversion with array type") { factory => val fieldTypes: Array[DataType] = Array( ArrayType(IntegerType), ArrayType(ArrayType(IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(1, 2)) @@ -347,12 +355,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } - test("basic conversion with map type") { + testWithFactory("basic conversion with map type") { factory => val fieldTypes: Array[DataType] = Array( MapType(IntegerType, IntegerType), MapType(IntegerType, MapType(IntegerType, IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val map1 = createMap(1, 2)(3, 4) @@ -393,12 +401,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } - test("basic conversion with struct and array") { + testWithFactory("basic conversion with struct and array") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("arr", 
ArrayType(IntegerType)), ArrayType(new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createArray(1))) @@ -432,12 +440,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - test("basic conversion with struct and map") { + testWithFactory("basic conversion with struct and map") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("map", MapType(IntegerType, IntegerType)), MapType(IntegerType, new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createMap(1)(2))) @@ -478,12 +486,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - test("basic conversion with array and map") { + testWithFactory("basic conversion with array and map") { factory => val fieldTypes: Array[DataType] = Array( ArrayType(MapType(IntegerType, IntegerType)), MapType(IntegerType, ArrayType(IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(createMap(1)(2))) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala similarity index 58% rename from external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala index 7aa7dd096c07..378b8bc055e3 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -15,20 +15,24 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.kafka010 +package org.apache.spark.sql.catalyst.expressions.codegen -import org.scalatest.PrivateMethodTester +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.BooleanType -import org.apache.spark.sql.test.SharedSQLContext +class ExprValueSuite extends SparkFunSuite { -class CachedKafkaConsumerSuite extends SharedSQLContext with PrivateMethodTester { + test("TrueLiteral and FalseLiteral should be LiteralValue") { + val trueLit = TrueLiteral + val falseLit = FalseLiteral - test("SPARK-19886: Report error cause correctly in reportDataLoss") { - val cause = new Exception("D'oh!") - val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) - val e = intercept[IllegalStateException] { - CachedKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) - } - assert(e.getCause === cause) + assert(trueLit.value == "true") + assert(falseLit.value == "false") + + assert(trueLit.isPrimitive) + assert(falseLit.isPrimitive) + + assert(trueLit === JavaCode.literal("true", BooleanType)) + assert(falseLit === JavaCode.literal("false", BooleanType)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index f203f25ad10d..75c6beeb3215 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -22,8 +22,10 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Test suite for [[GenerateUnsafeRowJoiner]]. 
@@ -45,6 +47,32 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { testConcat(64, 64, fixed) } + test("rows with all empty strings") { + val schema = StructType(Seq( + StructField("f1", StringType), StructField("f2", StringType))) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(UTF8String.EMPTY_UTF8, UTF8String.EMPTY_UTF8)) + testConcat(schema, row, schema, row) + } + + test("rows with all empty int arrays") { + val schema = StructType(Seq( + StructField("f1", ArrayType(IntegerType)), StructField("f2", ArrayType(IntegerType)))) + val emptyIntArray = + ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(emptyIntArray, emptyIntArray)) + testConcat(schema, row, schema, row) + } + + test("alternating empty and non-empty strings") { + val schema = StructType(Seq( + StructField("f1", StringType), StructField("f2", StringType))) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(UTF8String.EMPTY_UTF8, UTF8String.fromString("foo"))) + testConcat(schema, row, schema, row) + } + test("randomized fix width types") { for (i <- 0 until 20) { testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed) @@ -94,27 +122,84 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply() val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow]) val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow]) + testConcat(schema1, row1, schema2, row2) + } + + private def testConcat( + schema1: StructType, + row1: UnsafeRow, + schema2: StructType, + row2: UnsafeRow) { // Run the joiner. val mergedSchema = StructType(schema1 ++ schema2) val concater = GenerateUnsafeRowJoiner.create(schema1, schema2) - val output = concater.join(row1, row2) + val output: UnsafeRow = concater.join(row1, row2) + + // We'll also compare to an UnsafeRow produced with JoinedRow + UnsafeProjection. This ensures + // that unused space in the row (e.g. leftover bits in the null-tracking bitmap) is written + // correctly. + val expectedOutput: UnsafeRow = { + val joinedRowProjection = UnsafeProjection.create(mergedSchema) + val joined = new JoinedRow() + joinedRowProjection.apply(joined.apply(row1, row2)) + } // Test everything equals ... 
for (i <- mergedSchema.indices) { + val dataType = mergedSchema(i).dataType if (i < schema1.size) { assert(output.isNullAt(i) === row1.isNullAt(i)) if (!output.isNullAt(i)) { - assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType)) + assert(output.get(i, dataType) === row1.get(i, dataType)) + assert(output.get(i, dataType) === expectedOutput.get(i, dataType)) } } else { assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size)) if (!output.isNullAt(i)) { - assert(output.get(i, mergedSchema(i).dataType) === - row2.get(i - schema1.size, mergedSchema(i).dataType)) + assert(output.get(i, dataType) === row2.get(i - schema1.size, dataType)) + assert(output.get(i, dataType) === expectedOutput.get(i, dataType)) } } } + + + assert( + expectedOutput.getSizeInBytes == output.getSizeInBytes, + "output isn't same size in bytes as slow path") + + // Compare the UnsafeRows byte-by-byte so that we can print more useful debug information in + // case this assertion fails: + val actualBytes = output.getBaseObject.asInstanceOf[Array[Byte]] + .take(output.getSizeInBytes) + val expectedBytes = expectedOutput.getBaseObject.asInstanceOf[Array[Byte]] + .take(expectedOutput.getSizeInBytes) + + val bitsetWidth = UnsafeRow.calculateBitSetWidthInBytes(expectedOutput.numFields()) + val actualBitset = actualBytes.take(bitsetWidth) + val expectedBitset = expectedBytes.take(bitsetWidth) + assert(actualBitset === expectedBitset, "bitsets were not equal") + + val fixedLengthSize = expectedOutput.numFields() * 8 + val actualFixedLength = actualBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize) + val expectedFixedLength = expectedBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize) + if (actualFixedLength !== expectedFixedLength) { + actualFixedLength.grouped(8) + .zip(expectedFixedLength.grouped(8)) + .zip(mergedSchema.fields.toIterator) + .foreach { + case ((actual, expected), field) => + assert(actual === expected, s"Fixed length sections are not equal for field $field") + } + fail("Fixed length sections were not equal") + } + + val variableLengthStart = bitsetWidth + fixedLengthSize + val actualVariableLength = actualBytes.drop(variableLengthStart) + val expectedVariableLength = expectedBytes.drop(variableLengthStart) + assert(actualVariableLength === expectedVariableLength, "fixed length sections were not equal") + + assert(output.hashCode() == expectedOutput.hashCode(), "hash codes were not equal") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala new file mode 100644 index 000000000000..1b25a4b191f8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String + +class UTF8StringBuilderSuite extends SparkFunSuite { + + test("basic test") { + val sb = new UTF8StringBuilder() + assert(sb.build() === UTF8String.EMPTY_UTF8) + + sb.append("") + assert(sb.build() === UTF8String.EMPTY_UTF8) + + sb.append("abcd") + assert(sb.build() === UTF8String.fromString("abcd")) + + sb.append(UTF8String.fromString("1234")) + assert(sb.build() === UTF8String.fromString("abcd1234")) + + // expect to grow an internal buffer + sb.append(UTF8String.fromString("efgijk567890")) + assert(sb.build() === UTF8String.fromString("abcd1234efgijk567890")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index a3184a4266c7..f8ddc9359707 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -67,10 +67,9 @@ class AggregateOptimizeSuite extends PlanTest { } test("remove repetition in grouping expression") { - val input = LocalRelation('a.int, 'b.int, 'c.int) - val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) + val query = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze + val correctAnswer = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala index 21220b38968e..788fedb3c8e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala @@ -56,7 +56,7 @@ class CheckCartesianProductsSuite extends PlanTest { val thrownException = the [AnalysisException] thrownBy { performCartesianProductCheck(joinType) } - assert(thrownException.message.contains("Detected cartesian product")) + assert(thrownException.message.contains("Detected implicit cartesian product")) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 85a5e979f602..82a10254d846 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -809,6 +809,19 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, 
correctAnswer) } + test("aggregate: don't push filters if the aggregate has no grouping expressions") { + val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty) + .select('a, 'b) + .groupBy()(count(1)) + .where(false) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } + test("broadcast hint") { val originalQuery = ResolvedHint(testRelation) .where('a === 2L && 'b + Rand(10).as("rnd") === 3) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index dccb32f0379a..c28844642aed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -147,8 +147,8 @@ class FoldablePropagationSuite extends PlanTest { test("Propagate in expand") { val c1 = Literal(1).as('a) val c2 = Literal(2).as('b) - val a1 = c1.toAttribute.withNullability(true) - val a2 = c2.toAttribute.withNullability(true) + val a1 = c1.toAttribute.newInstance().withNullability(true) + val a2 = c2.toAttribute.newInstance().withNullability(true) val expand = Expand( Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))), Seq(a1, a2), @@ -161,4 +161,23 @@ class FoldablePropagationSuite extends PlanTest { val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze comparePlans(optimized, correctAnswer) } + + test("Propagate above outer join") { + val left = LocalRelation('a.int).select('a, Literal(1).as('b)) + val right = LocalRelation('c.int).select('c, Literal(1).as('d)) + + val join = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && 'b === 'd)) + val query = join.select(('b + 3).as('res)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && Literal(1) === Literal(1))) + .select((Literal(1) + 3).as('res)).analyze + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 5580f8604ec7..e4671f0d1cce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -34,11 +34,26 @@ class InferFiltersFromConstraintsSuite extends PlanTest { PushDownPredicate, InferFiltersFromConstraints, CombineFilters, - BooleanSimplification) :: Nil + SimplifyBinaryComparison, + BooleanSimplification, + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private def testConstraintsAfterJoin( + x: LogicalPlan, + y: LogicalPlan, + expectedLeft: LogicalPlan, + expectedRight: LogicalPlan, + joinType: JoinType) = { + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, joinType, condition).analyze + val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + test("filter: filter out constraints in condition") { val originalQuery = 
testRelation.where('a === 1 && 'a === 'b).analyze val correctAnswer = testRelation @@ -160,64 +175,6 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("inner join with alias: don't generate constraints for recursive functions") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) - - // We should prevent `Coalese(a, b)` from recursively creating complicated constraints through - // the constraint inference procedure. - val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) - // We hide an `Alias` inside the child's child's expressions, to cover the situation reported - // in [SPARK-20700]. - .select('int_col, 'd, 'a).as("t") - .join(t2, Inner, - Some("t.a".attr === "t2.a".attr - && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr)) - .analyze - val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a))) - && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b))) - && 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b)) - && 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b)) - && 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b))) - .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) - .select('int_col, 'd, 'a).as("t") - .join( - t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && - 'a === Coalesce(Seq('a, 'a))), - Inner, - Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr)) - .analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) - } - - test("inner join with EqualTo expressions containing part of each other: don't generate " + - "constraints for recursive functions") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) - - // We should prevent `c = Coalese(a, b)` and `a = Coalese(b, c)` from recursively creating - // complicated constraints through the constraint inference procedure. 
- val originalQuery = t1 - .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) - .where('a === 'd && 'c === 'e) - .join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) - .analyze - val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) && - 'c === Coalesce(Seq('a, 'b))) - .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) - .join(t2.where(IsNotNull('a) && IsNotNull('c)), - Inner, - Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) - .analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) - } - test("generate correct filters for alias that don't produce recursive constraints") { val t1 = testRelation.subquery('t1) @@ -236,4 +193,74 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, originalQuery) } } + + test("constraints should be inferred from aliased literals") { + val originalLeft = testRelation.subquery('left).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") + + val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") + val condition = Some("left.a".attr === "right.two".attr) + + val original = originalLeft.join(right, Inner, condition) + val correct = optimizedLeft.join(right, Inner, condition) + + comparePlans(Optimize.execute(original.analyze), correct.analyze) + } + + test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi) + } + + test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, LeftOuter, condition).where("x.a".attr === 2).analyze + val left = x.where(IsNotNull('a) && 'a === 2) + val right = y.where(IsNotNull('a) && 'a === 2) + val correctAnswer = left.join(right, LeftOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y.where("y.a".attr > 5), RightOuter, condition).analyze + val left = x.where(IsNotNull('a) && 'a > 5) + val right = y.where(IsNotNull('a) && 'a > 5) + val correctAnswer = left.join(right, RightOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-21479: Outer join no filter push down to preserved side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin( + x, y.where("a".attr === 1), + x, y.where(IsNotNull('a) && 'a === 1), + LeftOuter) + } + + test("SPARK-23564: left anti join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti) + } + + test("SPARK-23564: left outer join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + 
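+    // For a left outer join, right-side rows with a null join key can never match, so an
+    // IsNotNull filter can safely be inferred on the null-supplying right side, while the
+    // preserved left side must stay unfiltered to keep its unmatched rows (as checked below).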
testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter) + } + + test("SPARK-23564: right outer join should filter out null join keys on left side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 2fb587d50a4c..565b0a10154a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -62,24 +62,15 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { } } - /** Set up tables and columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t5.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t5.v-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4) + attr("t1.k-1-2") -> rangeColumnStat(2, 0), + attr("t1.v-1-10") -> rangeColumnStat(10, 0), + attr("t2.k-1-5") -> rangeColumnStat(5, 0), + attr("t3.v-1-100") -> rangeColumnStat(100, 0), + attr("t4.k-1-2") -> rangeColumnStat(2, 0), + attr("t4.v-1-10") -> rangeColumnStat(10, 0), + attr("t5.k-1-5") -> rangeColumnStat(5, 0), + attr("t5.v-1-5") -> rangeColumnStat(5, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index cc98d2350c77..17fb9fc5d11e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -93,7 +93,21 @@ class LimitPushdownSuite extends PlanTest { test("left outer join") { val originalQuery = x.join(y, LeftOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, y).join(y, LeftOuter)).analyze + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("left outer join and left sides are limited") { + val originalQuery = x.limit(2).join(y, LeftOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + 
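+  // A left outer join emits every row of its left input at least once, so a limit above the
+  // join can also be applied to the left input as a LocalLimit without changing the first
+  // rows returned; no limit is pushed into the null-supplying right side. A rough sketch of
+  // the rewrite exercised by these tests (illustrative only):
+  //
+  //   x.join(y, LeftOuter).limit(n)  ==>  Limit(n, LocalLimit(n, x).join(y, LeftOuter))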
test("left outer join and right sides are limited") { + val originalQuery = x.join(y.limit(2), LeftOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(Limit(2, y), LeftOuter)).analyze comparePlans(optimized, correctAnswer) } @@ -104,6 +118,20 @@ class LimitPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("right outer join and right sides are limited") { + val originalQuery = x.join(y.limit(2), RightOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("right outer join and left sides are limited") { + val originalQuery = x.limit(2).join(y, RightOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, Limit(2, x).join(LocalLimit(1, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + test("larger limits are not pushed on top of smaller ones in right outer join") { val originalQuery = x.join(y.limit(5), RightOuter).limit(10) val optimized = Optimize.execute(originalQuery.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 3964508e3a55..f1ce7543ffdc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -37,7 +37,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceIntersectWithSemiJoin, PushDownPredicate, PruneFilters, - PropagateEmptyRelation) :: Nil + PropagateEmptyRelation, + CollapseProject) :: Nil } object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] { @@ -48,7 +49,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters, + CollapseProject) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) @@ -79,9 +81,11 @@ class PropagateEmptyRelationSuite extends PlanTest { (true, false, Inner, Some(LocalRelation('a.int, 'b.int))), (true, false, Cross, Some(LocalRelation('a.int, 'b.int))), - (true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, LeftOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), - (true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, FullOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, LeftAnti, Some(testRelation1)), (true, false, LeftSemi, Some(LocalRelation('a.int))), @@ 
-89,8 +93,9 @@ class PropagateEmptyRelationSuite extends PlanTest { (false, true, Cross, Some(LocalRelation('a.int, 'b.int))), (false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))), (false, true, RightOuter, - Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), - (false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), + (false, true, FullOuter, + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), (false, true, LeftAnti, Some(LocalRelation('a.int))), (false, true, LeftSemi, Some(LocalRelation('a.int))), @@ -209,4 +214,11 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("propagate empty relation keeps the plan resolved") { + val query = testRelation1.join( + LocalRelation('a.int, 'b.int), UsingJoin(FullOuter, "a" :: Nil), None) + val optimized = Optimize.execute(query.analyze) + assert(optimized.resolved) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala new file mode 100644 index 000000000000..dae5e6f3ee3d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class RemoveRedundantSortsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :: + Batch("Collapse Project", Once, + CollapseProject) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("remove redundant order by") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val optimized = Optimize.execute(unnecessaryReordered.analyze) + val correctAnswer = orderedPlan.limit(2).select('a).analyze + comparePlans(Optimize.execute(optimized), correctAnswer) + } + + test("do not remove sort if the order is different") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(reorderedDifferently.analyze) + val correctAnswer = reorderedDifferently.analyze + comparePlans(optimized, correctAnswer) + } + + test("filters don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.where('a > Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("limits don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.limit(Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("different sorts are not simplified if limit is in between") { + val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10)) + .orderBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = orderedPlan.analyze + comparePlans(optimized, correctAnswer) + } + + test("range is already sorted") { + val inputPlan = Range(1L, 1000L, 1, 10) + val orderedPlan = inputPlan.orderBy('id.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = inputPlan.analyze + comparePlans(optimized, correctAnswer) + + val reversedPlan = inputPlan.orderBy('id.desc) + val reversedOptimized = Optimize.execute(reversedPlan.analyze) + val reversedCorrectAnswer = reversedPlan.analyze + comparePlans(reversedOptimized, reversedCorrectAnswer) + + val negativeStepInputPlan = Range(10L, 1L, -1, 10) + val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc) + val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze) + val negativeStepCorrectAnswer = negativeStepInputPlan.analyze + comparePlans(negativeStepOptimized, negativeStepCorrectAnswer) + } + + test("sort should not be removed when there is a node which doesn't guarantee any order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc) + val 
groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc) + val optimized = Optimize.execute(groupedAndResorted.analyze) + val correctAnswer = groupedAndResorted.analyze + comparePlans(optimized, correctAnswer) + } + + test("remove two consecutive sorts") { + val orderedTwice = testRelation.orderBy('a.asc).orderBy('b.desc) + val optimized = Optimize.execute(orderedTwice.analyze) + val correctAnswer = testRelation.orderBy('b.desc).analyze + comparePlans(optimized, correctAnswer) + } + + test("remove sorts separated by Filter/Project operators") { + val orderedTwiceWithProject = testRelation.orderBy('a.asc).select('b).orderBy('b.desc) + val optimizedWithProject = Optimize.execute(orderedTwiceWithProject.analyze) + val correctAnswerWithProject = testRelation.select('b).orderBy('b.desc).analyze + comparePlans(optimizedWithProject, correctAnswerWithProject) + + val orderedTwiceWithFilter = + testRelation.orderBy('a.asc).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithFilter = Optimize.execute(orderedTwiceWithFilter.analyze) + val correctAnswerWithFilter = testRelation.where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithFilter, correctAnswerWithFilter) + + val orderedTwiceWithBoth = + testRelation.orderBy('a.asc).select('b).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithBoth = Optimize.execute(orderedTwiceWithBoth.analyze) + val correctAnswerWithBoth = + testRelation.select('b).where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithBoth, correctAnswerWithBoth) + + val orderedThrice = orderedTwiceWithBoth.select(('b + 1).as('c)).orderBy('c.asc) + val optimizedThrice = Optimize.execute(orderedThrice.analyze) + val correctAnswerThrice = testRelation.select('b).where('b > Literal(0)) + .select(('b + 1).as('c)).orderBy('c.asc).analyze + comparePlans(optimizedThrice, correctAnswerThrice) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 0fa1aaeb9e16..52dc2e9fb076 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Not} +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ @@ -168,6 +168,21 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("replace Except with Filter when only right filter can be applied to the left") { + val table = LocalRelation(Seq('a.int, 'b.int)) + val left = table.where('b < 1).select('a).as("left") + val right = table.where('b < 3).select('a).as("right") + + val query = Except(left, right) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(left.output, right.output, + Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"))).analyze + + comparePlans(optimized, correctAnswer) + } + test("replace Distinct with Aggregate") { val input = 
LocalRelation('a.int, 'b.int) @@ -198,6 +213,14 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("add one grouping key if necessary when replace Deduplicate with Aggregate") { + val input = LocalRelation() + val query = Deduplicate(Seq.empty, input) // dropDuplicates() + val optimized = Optimize.execute(query.analyze) + val correctAnswer = Aggregate(Seq(Literal(1)), input.output, input) + comparePlans(optimized, correctAnswer) + } + test("don't replace streaming Deduplicate") { val input = LocalRelation(Seq('a.int, 'b.int), isStreaming = true) val attrA = input.output(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala index ada6e2a43ea0..d4d23ad69b2c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -68,88 +68,56 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( // F1 (fact table) - attr("f1_fk1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk3") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk1") -> rangeColumnStat(100, 0), + attr("f1_fk2") -> rangeColumnStat(100, 0), + attr("f1_fk3") -> rangeColumnStat(100, 0), + attr("f1_c1") -> rangeColumnStat(100, 0), + attr("f1_c2") -> rangeColumnStat(100, 0), // D1 (dimension) - attr("d1_pk") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c2") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c3") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_pk") -> rangeColumnStat(100, 0), + attr("d1_c2") -> rangeColumnStat(50, 0), + attr("d1_c3") -> rangeColumnStat(50, 0), // D2 (dimension) - attr("d2_pk") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_pk") -> rangeColumnStat(20, 0), + attr("d2_c2") -> rangeColumnStat(10, 0), + attr("d2_c3") -> rangeColumnStat(10, 0), // D3 (dimension) - attr("d3_pk") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, 
maxLen = 4), + attr("d3_pk") -> rangeColumnStat(10, 0), + attr("d3_c2") -> rangeColumnStat(5, 0), + attr("d3_c3") -> rangeColumnStat(5, 0), // T1 (regular table i.e. outside star) - attr("t1_c1") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t1_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t1_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c1") -> rangeColumnStat(20, 1), + attr("t1_c2") -> rangeColumnStat(10, 1), + attr("t1_c3") -> rangeColumnStat(10, 1), // T2 (regular table) - attr("t2_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t2_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t2_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c1") -> rangeColumnStat(5, 1), + attr("t2_c2") -> rangeColumnStat(5, 1), + attr("t2_c3") -> rangeColumnStat(5, 1), // T3 (regular table) - attr("t3_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c1") -> rangeColumnStat(5, 1), + attr("t3_c2") -> rangeColumnStat(5, 1), + attr("t3_c3") -> rangeColumnStat(5, 1), // T4 (regular table) - attr("t4_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t4_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t4_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c1") -> rangeColumnStat(5, 1), + attr("t4_c2") -> rangeColumnStat(5, 1), + attr("t4_c3") -> rangeColumnStat(5, 1), // T5 (regular table) - attr("t5_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t5_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t5_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c1") -> rangeColumnStat(5, 1), + attr("t5_c2") -> rangeColumnStat(5, 1), + attr("t5_c3") -> rangeColumnStat(5, 1), // T6 (regular table) - attr("t6_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t6_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t6_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4) + attr("t6_c1") -> rangeColumnStat(5, 1), + attr("t6_c2") -> rangeColumnStat(5, 1), + attr("t6_c3") -> rangeColumnStat(5, 1) )) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 777c5637201e..4e0883e91e84 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -70,59 +70,40 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { // Tables' cardinality: f1 > d3 > d1 > d2 > s3 private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( // F1 - attr("f1_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk1") -> rangeColumnStat(3, 0), + attr("f1_fk2") -> rangeColumnStat(3, 0), + attr("f1_fk3") -> rangeColumnStat(4, 0), + attr("f1_c4") -> rangeColumnStat(4, 0), // D1 - attr("d1_pk1") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_pk1") -> rangeColumnStat(4, 0), + attr("d1_c2") -> rangeColumnStat(3, 0), + attr("d1_c3") -> rangeColumnStat(4, 0), + attr("d1_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // D2 - attr("d2_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("d2_pk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c3") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c2") -> ColumnStat(distinctCount = Some(3), min = Some("1"), max = Some("3"), + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4)), + attr("d2_pk1") -> rangeColumnStat(3, 0), + attr("d2_c3") -> rangeColumnStat(3, 0), + attr("d2_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // D3 - attr("d3_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_pk1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_fk1") -> rangeColumnStat(3, 0), + attr("d3_c2") -> rangeColumnStat(3, 0), + attr("d3_pk1") -> rangeColumnStat(5, 0), + attr("d3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // S3 - attr("s3_pk1") -> ColumnStat(distinctCount = 2, min = Some(1), 
max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c2") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c3") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_pk1") -> rangeColumnStat(2, 0), + attr("s3_c2") -> rangeColumnStat(1, 0), + attr("s3_c3") -> rangeColumnStat(1, 0), + attr("s3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // F11 - attr("f11_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4) + attr("f11_fk1") -> rangeColumnStat(3, 0), + attr("f11_fk2") -> rangeColumnStat(3, 0), + attr("f11_fk3") -> rangeColumnStat(4, 0), + attr("f11_c4") -> rangeColumnStat(4, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala index 56f096f3ecf8..5fc99a3a57c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -38,18 +39,19 @@ class TypedFilterOptimizationSuite extends PlanTest { implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + val testRelation = LocalRelation('_1.int, '_2.int) + test("filter after serialize with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 - val query = input + val query = testRelation .deserialize[(Int, Int)] .serialize[(Int, Int)] .filter(f).analyze val optimized = Optimize.execute(query) - val expected = input + val expected = testRelation .deserialize[(Int, Int)] .where(callFunction(f, BooleanType, 'obj)) .serialize[(Int, Int)].analyze @@ -58,10 +60,9 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter after serialize with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: OtherTuple) => i._1 > 0 - val query = input + val query = testRelation .deserialize[(Int, Int)] .serialize[(Int, Int)] .filter(f).analyze @@ -70,17 +71,16 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter before deserialize with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 - val query = input + val query = testRelation .filter(f) 
.deserialize[(Int, Int)] .serialize[(Int, Int)].analyze val optimized = Optimize.execute(query) - val expected = input + val expected = testRelation .deserialize[(Int, Int)] .where(callFunction(f, BooleanType, 'obj)) .serialize[(Int, Int)].analyze @@ -89,10 +89,9 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter before deserialize with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: OtherTuple) => i._1 > 0 - val query = input + val query = testRelation .filter(f) .deserialize[(Int, Int)] .serialize[(Int, Int)].analyze @@ -101,21 +100,89 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("back to back filter with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f1 = (i: (Int, Int)) => i._1 > 0 val f2 = (i: (Int, Int)) => i._2 > 0 - val query = input.filter(f1).filter(f2).analyze + val query = testRelation.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) assert(optimized.collect { case t: TypedFilter => t }.length == 1) } test("back to back filter with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f1 = (i: (Int, Int)) => i._1 > 0 val f2 = (i: OtherTuple) => i._2 > 0 - val query = input.filter(f1).filter(f2).analyze + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("back to back FilterFunction with the same object type") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("back to back FilterFunction with different object types") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = new FilterFunction[OtherTuple] { + override def call(value: OtherTuple): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("FilterFunction and filter with the same object type") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = (i: (Int, Int)) => i._2 > 0 + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("FilterFunction and filter with different object types") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = (i: OtherTuple) => i._2 > 0 + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("filter and FilterFunction with the same object type") { + val f2 = (i: (Int, Int)) => i._1 > 0 + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } 
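The FilterFunction cases above and below exercise the Java-friendly side of the same optimization: adjacent typed filters over the same object type are merged into a single TypedFilter, so only one deserialize/filter/serialize round trip survives, while filters over different object types (the `length == 2` assertions) are left untouched. A minimal Dataset-level sketch of the shape these tests cover (illustration only, not part of the patch; assumes an active SparkSession named `spark`):

    import org.apache.spark.api.java.function.FilterFunction
    import spark.implicits._

    // Illustrative sketch, not from this patch.
    val ds = Seq((1, 2), (0, 3), (4, 5)).toDS()         // Dataset[(Int, Int)]
    val scalaPred = (p: (Int, Int)) => p._1 > 0         // Scala closure predicate
    val javaPred = new FilterFunction[(Int, Int)] {     // Java-style predicate
      override def call(p: (Int, Int)): Boolean = p._2 > 0
    }
    // Both predicates operate on the same object type, so the optimized plan
    // should contain a single TypedFilter rather than two.
    ds.filter(scalaPred).filter(javaPred).explain(true)
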
+ + test("filter and FilterFunction with different object types") { + val f2 = (i: (Int, Int)) => i._1 > 0 + val f1 = new FilterFunction[OtherTuple] { + override def call(value: OtherTuple): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) assert(optimized.collect { case t: TypedFilter => t }.length == 2) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala new file mode 100644 index 000000000000..09b11f5aba2a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class UpdateNullabilityInAttributeReferencesSuite extends PlanTest { + + object Optimizer extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Constant Folding", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + SimplifyBinaryComparison, + SimplifyExtractValueOps) :: + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) :: Nil + } + + test("update nullability in AttributeReference") { + val rel = LocalRelation('a.long.notNull) + // In the 'original' plans below, the Aggregate node produced by groupBy() has a + // nullable AttributeReference to `b`, because both array indexing and map lookup are + // nullable expressions. After optimization, the same attribute is now non-nullable, + // but the AttributeReference is not updated to reflect this. So, we need to update nullability + // by the `UpdateNullabilityInAttributeReferences` rule. 
+ val original = rel + .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b") + .groupBy($"b")("1") + val expected = rel.select('a as "b").groupBy($"b")("1").analyze + val optimized = Optimizer.execute(original.analyze) + comparePlans(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 0d11958876ce..633d86d49558 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** @@ -31,7 +32,7 @@ import org.apache.spark.sql.types._ * i.e. {{{create_named_struct(square, `x` * `x`).square}}} can be simplified to {{{`x` * `x`}}}. * sam applies to create_array and create_map */ -class ComplexTypesSuite extends PlanTest{ +class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { object Optimizer extends RuleExecutor[LogicalPlan] { val batches = @@ -43,14 +44,20 @@ class ComplexTypesSuite extends PlanTest{ BooleanSimplification, SimplifyConditionals, SimplifyBinaryComparison, - SimplifyCreateStructOps, - SimplifyCreateArrayOps, - SimplifyCreateMapOps) :: Nil + SimplifyExtractValueOps) :: Nil } - val idAtt = ('id).long.notNull + private val idAtt = ('id).long.notNull + private val nullableIdAtt = ('nullable_id).long - lazy val relation = LocalRelation(idAtt ) + private val relation = LocalRelation(idAtt, nullableIdAtt) + private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int) + + private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = { + val optimized = Optimizer.execute(originalQuery.analyze) + assert(optimized.resolved, "optimized plans must be still resolvable") + comparePlans(optimized, correctAnswer.analyze) + } test("explicit get from namedStruct") { val query = relation @@ -58,31 +65,28 @@ class ComplexTypesSuite extends PlanTest{ GetStructField( CreateNamedStruct(Seq("att", 'id )), 0, - None) as "outerAtt").analyze - val expected = relation.select('id as "outerAtt").analyze + None) as "outerAtt") + val expected = relation.select('id as "outerAtt") - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("explicit get from named_struct- expression maintains original deduced alias") { val query = relation .select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None)) - .analyze val expected = relation .select('id as "named_struct(att, id).att") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed getStructField ontop of namedStruct") { val query = relation .select(CreateNamedStruct(Seq("att", 'id)) as "struct1") .select(GetStructField('struct1, 0, None) as "struct1Att") - .analyze - val expected = relation.select('id as "struct1Att").analyze - comparePlans(Optimizer execute query, expected) + val expected = relation.select('id as "struct1Att") + checkRule(query, expected) } test("collapse multiple CreateNamedStruct/GetStructField pairs") { @@ 
-94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest{ .select( GetStructField('struct1, 0, None) as "struct1Att1", GetStructField('struct1, 1, None) as "struct1Att2") - .analyze val expected = relation. select( 'id as "struct1Att1", ('id * 'id) as "struct1Att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed2 - deduced names") { @@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest{ .select( GetStructField('struct1, 0, None), GetStructField('struct1, 1, None)) - .analyze val expected = relation. select( 'id as "struct1.att1", ('id * 'id) as "struct1.att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplified array ops") { @@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest{ 1, false), 1) as "a4") - .analyze val expected = relation .select( @@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest{ "att2", (('id + 1L) * ('id + 1L)))) as "a2", ('id + 1L) as "a3", ('id + 1L) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("SPARK-22570: CreateArray should not create a lot of global variables") { @@ -171,6 +169,11 @@ class ComplexTypesSuite extends PlanTest{ assert(ctx.inlinedMutableStates.length == 0) } + test("SPARK-23208: Test code splitting for create array related methods") { + val inputs = (1 to 2500).map(x => Literal(s"l_$x")) + checkEvaluation(CreateArray(inputs), new GenericArrayData(inputs.map(_.eval()))) + } + test("simplify map ops") { val rel = relation .select( @@ -183,7 +186,6 @@ class ComplexTypesSuite extends PlanTest{ GetStructField(GetMapValue('m, "r1"), 0, None) as "a2", GetMapValue('m, "r32") as "a3", GetStructField(GetMapValue('m, "r32"), 0, None) as "a4") - .analyze val expected = relation.select( @@ -196,8 +198,7 @@ class ComplexTypesSuite extends PlanTest{ ) ) as "a3", Literal.create(null, LongType) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, constant lookup, dynamic keys") { @@ -211,7 +212,6 @@ class ComplexTypesSuite extends PlanTest{ ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -220,8 +220,7 @@ class ComplexTypesSuite extends PlanTest{ (EqualTo(13L, ('id + 1L)), ('id + 2L)), (EqualTo(13L, ('id + 2L)), ('id + 3L)), (Literal(true), 'id))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") { @@ -235,7 +234,6 @@ class ComplexTypesSuite extends PlanTest{ ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), ('id + 3L)) as "a") - .analyze val expected = relation .select( CaseWhen(Seq( @@ -243,8 +241,7 @@ class ComplexTypesSuite extends PlanTest{ (EqualTo('id + 3L, ('id + 1L)), ('id + 2L)), (EqualTo('id + 3L, ('id + 2L)), ('id + 3L)), (Literal(true), ('id + 4L)))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, no positive match") { @@ -258,7 +255,6 @@ class ComplexTypesSuite extends PlanTest{ ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 'id + 30L) as "a") - .analyze val expected = relation.select( CaseWhen(Seq( (EqualTo('id + 30L, 'id), ('id + 1L)), @@ -266,8 +262,7 @@ class ComplexTypesSuite extends PlanTest{ (EqualTo('id + 30L, ('id + 2L)), ('id + 3L)), (EqualTo('id + 30L, ('id + 3L)), ('id + 4L)), (EqualTo('id + 
30L, ('id + 4L)), ('id + 5L)))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, constant lookup, mixed keys, eliminated constants") { @@ -282,7 +277,6 @@ class ComplexTypesSuite extends PlanTest{ ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -292,9 +286,8 @@ class ComplexTypesSuite extends PlanTest{ ('id + 2L), ('id + 3L), ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, potential dynamic match with null value + an absolute constant match") { @@ -309,20 +302,141 @@ class ComplexTypesSuite extends PlanTest{ ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 2L ) as "a") - .analyze val expected = relation .select( CaseWhen(Seq( (EqualTo(2L, 'id), ('id + 1L)), - // these two are possible matches, we can't tell untill runtime + // these two are possible matches, we can't tell until runtime (EqualTo(2L, ('id + 1L)), ('id + 2L)), (EqualTo(2L, 'id + 2L), Literal.create(null, LongType)), // this is a definite match (two constants), // but it cannot override a potential match with ('id + 2L), // which is exactly what [[Coalesce]] would do in this case. (Literal.TrueLiteral, 'id))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) + } + + test("SPARK-23500: Simplify array ops that are not at the top node") { + val query = LocalRelation('id.long) + .select( + CreateArray(Seq( + CreateNamedStruct(Seq( + "att1", 'id, + "att2", 'id * 'id)), + CreateNamedStruct(Seq( + "att1", 'id + 1, + "att2", ('id + 1) * ('id + 1)) + )) + ) as "arr") + .select( + GetStructField(GetArrayItem('arr, 1), 0, None) as "a1", + GetArrayItem( + GetArrayStructFields('arr, + StructField("att1", LongType, nullable = false), + ordinal = 0, + numFields = 1, + containsNull = false), + ordinal = 1) as "a2") + .orderBy('id.asc) + + val expected = LocalRelation('id.long) + .select( + ('id + 1L) as "a1", + ('id + 1L) as "a2") + .orderBy('id.asc) + checkRule(query, expected) + } + + test("SPARK-23500: Simplify map ops that are not top nodes") { + val query = + LocalRelation('id.long) + .select( + CreateMap(Seq( + "r1", 'id, + "r2", 'id + 1L)) as "m") + .select( + GetMapValue('m, "r1") as "a1", + GetMapValue('m, "r32") as "a2") + .orderBy('id.asc) + .select('a1, 'a2) + + val expected = + LocalRelation('id.long).select( + 'id as "a1", + Literal.create(null, LongType) as "a2") + .orderBy('id.asc) + checkRule(query, expected) + } + + test("SPARK-23500: Simplify complex ops that aren't at the plan root") { + val structRel = relation + .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") + .groupBy($"foo")("1") + val structExpected = relation + .select('nullable_id as "foo") + .groupBy($"foo")("1") + checkRule(structRel, structExpected) + + val arrayRel = relation + .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") + .groupBy($"a1")("1") + val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1") + checkRule(arrayRel, arrayExpected) + + val mapRel = relation + .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") + .groupBy($"m1")("1") + val mapExpected = relation + .select('nullable_id as "m1") + .groupBy($"m1")("1") + checkRule(mapRel, mapExpected) + } + + test("SPARK-23500: Ensure that aggregation expressions are not simplified") { + // Make 
sure that aggregation exprs are correctly ignored. Maps can't be used in + // grouping exprs so aren't tested here. + val structAggRel = relation.groupBy( + CreateNamedStruct(Seq("att1", 'nullable_id)))( + GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)) + checkRule(structAggRel, structAggRel) + + val arrayAggRel = relation.groupBy( + CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)) + checkRule(arrayAggRel, arrayAggRel) + + // This could be done if we had a more complex rule that checks that + // the CreateMap does not come from key. + val originalQuery = relation + .groupBy('id)( + GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a" + ) + checkRule(originalQuery, originalQuery) + } + + test("SPARK-23500: namedStruct and getField in the same Project #1") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b) + .select('s1 getField "col2" as 's1Col2, + namedStruct("col1", 'a, "col2", 'b).as("s2")) + .select('s1Col2, 's2 getField "col2" as 's2Col2) + val correctAnswer = + testRelation + .select('c as 's1Col2, 'b as 's2Col2) + checkRule(originalQuery, correctAnswer) + } + + test("SPARK-23500: namedStruct and getField in the same Project #2") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2, + namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1) + val correctAnswer = + testRelation + .select('c as 'sCol2, 'a as 'sCol1) + checkRule(originalQuery, correctAnswer) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 2b9783a3295c..cb8a1fecb80a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -249,8 +249,8 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) - assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) @@ -263,21 +263,62 @@ class ExpressionParserSuite extends PlanTest { "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", WindowExpression('sum.function('product + 1), WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + } + + test("range/rows window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } - // Range/Row val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) val boundaries = 
Seq( - ("10 preceding", -Literal(10), CurrentRow), + // No between combinations + ("unbounded preceding", UnboundedPreceding, CurrentRow), ("2147483648 preceding", -Literal(2147483648L), CurrentRow), + ("10 preceding", -Literal(10), CurrentRow), + ("3 + 1 preceding", -Add(Literal(3), Literal(1)), CurrentRow), + ("0 preceding", -Literal(0), CurrentRow), + ("current row", CurrentRow, CurrentRow), + ("0 following", Literal(0), CurrentRow), ("3 + 1 following", Add(Literal(3), Literal(1)), CurrentRow), - ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("10 following", Literal(10), CurrentRow), + ("2147483649 following", Literal(2147483649L), CurrentRow), ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + + // Between combinations + ("between unbounded preceding and 5 following", + UnboundedPreceding, Literal(5)), + ("between unbounded preceding and 3 + 1 following", + UnboundedPreceding, Add(Literal(3), Literal(1))), + ("between unbounded preceding and 2147483649 following", + UnboundedPreceding, Literal(2147483649L)), ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), - ("between unbounded preceding and unbounded following", - UnboundedPreceding, UnboundedFollowing), + ("between 2147483648 preceding and current row", -Literal(2147483648L), CurrentRow), ("between 10 preceding and current row", -Literal(10), CurrentRow), + ("between 3 + 1 preceding and current row", -Add(Literal(3), Literal(1)), CurrentRow), + ("between 0 preceding and current row", -Literal(0), CurrentRow), + ("between current row and current row", CurrentRow, CurrentRow), + ("between current row and 0 following", CurrentRow, Literal(0)), ("between current row and 5 following", CurrentRow, Literal(5)), - ("between 10 preceding and 5 following", -Literal(10), Literal(5)) + ("between current row and 3 + 1 following", CurrentRow, Add(Literal(3), Literal(1))), + ("between current row and 2147483649 following", CurrentRow, Literal(2147483649L)), + ("between current row and unbounded following", CurrentRow, UnboundedFollowing), + ("between 2147483648 preceding and unbounded following", + -Literal(2147483648L), UnboundedFollowing), + ("between 10 preceding and unbounded following", + -Literal(10), UnboundedFollowing), + ("between 3 + 1 preceding and unbounded following", + -Add(Literal(3), Literal(1)), UnboundedFollowing), + ("between 0 preceding and unbounded following", -Literal(0), UnboundedFollowing), + + // Between partial and full range + ("between 10 preceding and 5 following", -Literal(10), Literal(5)), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing) ) frameTypes.foreach { case (frameTypeSql, frameType) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index cc80a41df998..89903c282512 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -41,12 +41,12 @@ class TableIdentifierParserSuite extends SparkFunSuite { "sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables", "tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive", "undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", 
"utctimestamp", - "view", "while", "year", "work", "transaction", "write", "isolation", "level", - "snapshot", "autocommit", "all", "alter", "array", "as", "authorization", "between", "bigint", + "view", "while", "year", "work", "transaction", "write", "isolation", "level", "snapshot", + "autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint", "binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp", "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external", "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in", - "insert", "int", "into", "is", "lateral", "like", "local", "none", "null", + "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null", "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 866ff0d33cbb..a37e06d92264 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -134,6 +134,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { verifyConstraints(aliasedRelation.analyze.constraints, ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), + resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"), resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 82c5307d5436..6241d5cbb1d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -154,7 +154,7 @@ trait PlanTestBase extends PredicateHelper { self: Suite => } /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL * configurations. */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala new file mode 100644 index 000000000000..27914ef5565c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.types.IntegerType + +class QueryPlanSuite extends SparkFunSuite { + + test("origin remains the same after mapExpressions (SPARK-23823)") { + CurrentOrigin.setPosition(0, 0) + val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId) + val query = plans.DslLogicalPlan(plans.table("table")).select(column) + CurrentOrigin.reset() + + val mappedQuery = query mapExpressions { + case _: Expression => Literal(1) + } + + val mappedOrigin = mappedQuery.expressions.apply(0).origin + assert(mappedOrigin == Origin.apply(Some(0), Some(0))) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 23f95a6cc2ac..8213d568fe85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -29,16 +29,16 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { /** Columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key12") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key22") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key31") -> ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, - avgLen = 4, maxLen = 4) + attr("key11") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key12") -> ColumnStat(distinctCount = Some(4), min = Some(10), max = Some(40), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key21") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key22") -> ColumnStat(distinctCount = Some(2), min = Some(10), max = Some(20), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key31") -> ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, 
Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -63,8 +63,8 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { tableRowCount = 6, groupByColumns = Seq("key21", "key22"), // Row count = product of ndv - expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount * nameToColInfo("key22")._2 - .distinctCount) + expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount.get * + nameToColInfo("key22")._2.distinctCount.get) } test("empty group-by column") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 7d532ff34317..953094cb0dd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.IntegerType class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { val attribute = attr("key") - val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStat = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val plan = StatsTestPlan( outputList = Seq(attribute), @@ -116,13 +116,17 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { sizeInBytes = 40, rowCount = Some(10), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4)))) + AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10), + min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))))) val expectedCboStats = Statistics( sizeInBytes = 4, rowCount = Some(1), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4)))) + AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10), + min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))))) val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) checkStats( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 2b1fe987a796..47bfa6256958 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -37,59 +37,61 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 val attrInt = AttributeReference("cint", IntegerType)() - val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cbool has only 2 distinct values val attrBool = AttributeReference("cbool", BooleanType)() - val colStatBool = 
ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1) + val colStatBool = ColumnStat(distinctCount = Some(2), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)) // column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10")) val attrDate = AttributeReference("cdate", DateType)() - val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatDate = ColumnStat(distinctCount = Some(10), + min = Some(dMin), max = Some(dMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. val decMin = Decimal("0.200000000000000000") val decMax = Decimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() - val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDecimal = ColumnStat(distinctCount = Some(4), + min = Some(decMin), max = Some(decMax), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 val attrDouble = AttributeReference("cdouble", DoubleType)() - val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDouble = ColumnStat(distinctCount = Some(10), min = Some(1.0), max = Some(10.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" val attrString = AttributeReference("cstring", StringType)() - val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2) + val colStatString = ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)) // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint < cint2 val attrInt2 = AttributeReference("cint2", IntegerType)() - val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt2 = ColumnStat(distinctCount = Some(10), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint = cint3 without overlap at all. 
val attrInt3 = AttributeReference("cint3", IntegerType)() - val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt3 = ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint4 has values in the range from 1 to 10 // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 // This column is created to test complete overlap val attrInt4 = AttributeReference("cint4", IntegerType)() - val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt4 = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cintHgm has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 with histogram. // Note that cintHgm has an even distribution with histogram information built. @@ -98,8 +100,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmInt = Histogram(2.0, Array(HistogramBin(1.0, 2.0, 2), HistogramBin(2.0, 4.0, 2), HistogramBin(4.0, 6.0, 2), HistogramBin(6.0, 8.0, 2), HistogramBin(8.0, 10.0, 2))) - val colStatIntHgm = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt)) + val colStatIntHgm = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt)) // column cintSkewHgm has values: 1, 4, 4, 5, 5, 5, 5, 6, 6, 10 with histogram. // Note that cintSkewHgm has a skewed distribution with histogram information built. @@ -108,8 +110,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmIntSkew = Histogram(2.0, Array(HistogramBin(1.0, 4.0, 2), HistogramBin(4.0, 5.0, 2), HistogramBin(5.0, 5.0, 1), HistogramBin(5.0, 6.0, 2), HistogramBin(6.0, 10.0, 2))) - val colStatIntSkewHgm = ColumnStat(distinctCount = 5, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew)) + val colStatIntSkewHgm = ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew)) val attributeMap = AttributeMap(Seq( attrInt -> colStatInt, @@ -172,7 +174,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 3)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -180,7 +182,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } @@ -196,23 +198,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } 
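The repetitive rewrites in this suite (and in the aggregate and basic-stats estimation suites above) all stem from one API change: every ColumnStat field is now an Option, so literal statistics are wrapped in Some(...) and consumers unwrap them explicitly. A small before/after sketch, assuming ColumnStat keeps the field names, None defaults, and catalyst.plans.logical location used by these tests (illustration only, not part of the patch):

    import org.apache.spark.sql.catalyst.plans.logical.ColumnStat

    // 2.3-style construction (all fields required):
    //   ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
    //              nullCount = 0, avgLen = 4, maxLen = 4)

    // 2.4-style construction: every field defaults to None, so partially known
    // statistics can be expressed without dummy values.
    val fullStat = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10),
      nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
    val ndvOnly = ColumnStat(distinctCount = Some(1))

    // Call sites that used to read raw values now unwrap the Options, e.g. the
    // ndv product used for group-by cardinality in AggregateEstimationSuite:
    val groupByNdv = fullStat.distinctCount.get * ndvOnly.distinctCount.get
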
test("cint = 2") { validateEstimatedStats( Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cint <=> 2") { validateEstimatedStats( Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } @@ -227,8 +229,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint < 3") { validateEstimatedStats( Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -243,16 +245,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } test("cint > 6") { validateEstimatedStats( Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -267,8 +269,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -282,8 +284,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IS NOT NULL") { validateEstimatedStats( Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -301,8 +303,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(3), max = 
Some(6), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -310,7 +312,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 2)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(2))), expectedRowCount = 2) } @@ -318,7 +320,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 6)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(6))), expectedRowCount = 6) } @@ -326,7 +328,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 5)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -342,47 +344,70 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 9), - attrString -> colStatString.copy(distinctCount = 9)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(9)), + attrString -> colStatString.copy(distinctCount = Some(9))), expectedRowCount = 9) } test("cint IN (3, 4, 5)") { validateEstimatedStats( Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } + test("evaluateInSet with all zeros") { + validateEstimatedStats( + Filter(InSet(attrString, Set(3, 4, 5)), + StatsTestPlan(Seq(attrString), 0, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(0), maxLen = Some(0)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(0))), + expectedRowCount = 0) + } + + test("evaluateInSet with string") { + validateEstimatedStats( + Filter(InSet(attrString, Set("A0")), + StatsTestPlan(Seq(attrString), 10, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), + expectedRowCount = 1) + } + test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 7)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cbool IN (true)") { validateEstimatedStats( Filter(InSet(attrBool, 
Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool = true") { validateEstimatedStats( Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool > false") { validateEstimatedStats( Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } @@ -391,18 +416,21 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDate, Literal(d20170102, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(1), + min = Some(d20170102), max = Some(d20170102), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cdate < cast('2017-01-03' AS DATE)") { + val d20170101 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) validateEstimatedStats( Filter(LessThan(attrDate, Literal(d20170103, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170101), max = Some(d20170103), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -414,8 +442,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170103), max = Some(d20170105), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -424,42 +453,45 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDecimal, Literal(dec_0_40)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(1), + min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 1) } test("cdecimal < 0.60 ") { + val dec_0_20 = 
Decimal("0.200000000000000000") val dec_0_60 = Decimal("0.600000000000000000") validateEstimatedStats( Filter(LessThan(attrDecimal, Literal(dec_0_60)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(3), + min = Some(dec_0_20), max = Some(dec_0_60), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cdouble < 3.0") { validateEstimatedStats( Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)), - Seq(attrDouble -> ColumnStat(distinctCount = 3, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDouble -> ColumnStat(distinctCount = Some(3), min = Some(1.0), max = Some(3.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cstring = 'A2'") { validateEstimatedStats( Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 1) } test("cstring < 'A2' - unsupported condition") { validateEstimatedStats( Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 10) } @@ -468,8 +500,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // valid values in IN clause is greater than the number of distinct values for a given column. // For example, column has only 2 distinct values 1 and 6. // The predicate is: column IN (1, 2, 3, 4, 5). - val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4) + val cornerChildColStatInt = ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val cornerChildStatsTestplan = StatsTestPlan( outputList = Seq(attrInt), rowCount = 2L, @@ -477,16 +510,17 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) validateEstimatedStats( Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), - Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 2) } // This is a limitation test. We should remove it after the limitation is removed. 
test("don't estimate IsNull or IsNotNull if the child is a non-leaf node") { val attrIntLargerRange = AttributeReference("c1", IntegerType)() - val colStatIntLargerRange = ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 10, avgLen = 4, maxLen = 4) + val colStatIntLargerRange = ColumnStat(distinctCount = Some(20), + min = Some(1), max = Some(20), + nullCount = Some(10), avgLen = Some(4), maxLen = Some(4)) val smallerTable = childStatsTestPlan(Seq(attrInt), 10L) val largerTable = StatsTestPlan( outputList = Seq(attrIntLargerRange), @@ -508,10 +542,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -519,10 +553,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -530,10 +564,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -541,10 +575,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // complete overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + 
nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -552,10 +586,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -571,10 +605,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // all table records qualify. validateEstimatedStats( Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt3 -> ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -592,11 +626,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrInt4, attrString), 10L)), Seq( - attrInt -> ColumnStat(distinctCount = 5, min = Some(3), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4), - attrString -> colStatString.copy(distinctCount = 5)), + attrInt -> ColumnStat(distinctCount = Some(5), min = Some(3), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrString -> colStatString.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -606,15 +640,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 7)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cintHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntHgm, Literal(5)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 1) } @@ -629,8 +663,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntHgm, Literal(3)), 
childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } @@ -645,16 +679,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } test("cintHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -669,8 +703,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 5) } @@ -679,8 +713,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -688,7 +722,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntHgm, Literal(3)), EqualTo(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 3)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -698,15 +732,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntSkewHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 5)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(5))), expectedRowCount = 9) } test("cintSkewHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntSkewHgm, Literal(5)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - 
Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 4) } @@ -721,8 +755,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintSkewHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -738,16 +772,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(LessThanOrEqual(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } test("cintSkewHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -764,8 +798,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 2, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(2), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 3) } @@ -774,8 +808,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 8) } @@ -783,7 +817,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntSkewHgm, Literal(3)), EqualTo(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 2)), + Seq(attrIntSkewHgm -> 
colStatIntSkewHgm.copy(distinctCount = Some(2))), expectedRowCount = 3) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 26139d85d25f..12c0a7be2129 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -33,16 +33,16 @@ class JoinEstimationSuite extends StatsEstimationTestBase { /** Set up tables and its columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-5-9") -> ColumnStat(distinctCount = 5, min = Some(5), max = Some(9), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-4") -> ColumnStat(distinctCount = 3, min = Some(2), max = Some(4), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-3") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + attr("key-1-5") -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-5-9") -> ColumnStat(distinctCount = Some(5), min = Some(5), max = Some(9), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-1-2") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-4") -> ColumnStat(distinctCount = Some(3), min = Some(2), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-3") -> ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -70,8 +70,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def estimateByHistogram( leftHistogram: Histogram, rightHistogram: Histogram, - expectedMin: Double, - expectedMax: Double, + expectedMin: Any, + expectedMax: Any, expectedNdv: Long, expectedRows: Long): Unit = { val col1 = attr("key1") @@ -86,9 +86,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedRows), attributeStats = AttributeMap(Seq( col1 -> c1.stats.attributeStats(col1).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)), + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)), col2 -> c2.stats.attributeStats(col2).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)))) + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)))) ) // Join order should not affect estimation result. 
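The estimateByHistogram helper above derives the expected join statistics from two equi-height histograms on the join keys; only the range where the two key distributions overlap can produce matches, which is where the expected min, max and ndv come from. A small sketch of that overlap idea with assumed simplified bin and helper names, not the catalyst Histogram/JoinEstimation classes:

  // Assumed simplified bin type, standing in for catalyst's HistogramBin.
  case class SimpleBin(lo: Double, hi: Double, ndv: Long)

  def overlapRange(left: Seq[SimpleBin], right: Seq[SimpleBin]): Option[(Double, Double)] = {
    val lo = math.max(left.head.lo, right.head.lo)
    val hi = math.min(left.last.hi, right.last.hi)
    if (lo <= hi) Some((lo, hi)) else None  // no overlap means the equi-join produces no rows
  }

  // Keys covering [1, 10] joined with keys covering [5, 20] can only match on [5, 10].
  val leftBins = Seq(SimpleBin(1, 5, 5), SimpleBin(5, 10, 5))
  val rightBins = Seq(SimpleBin(5, 12, 8), SimpleBin(12, 20, 8))
  assert(overlapRange(leftBins, rightBins) == Some((5.0, 10.0)))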
@@ -100,9 +102,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def generateJoinChild( col: Attribute, histogram: Histogram, - expectedMin: Double, - expectedMax: Double): LogicalPlan = { - val colStat = inferColumnStat(histogram) + expectedMin: Any, + expectedMax: Any): LogicalPlan = { + val colStat = inferColumnStat(histogram, expectedMin, expectedMax) StatsTestPlan( outputList = Seq(col), rowCount = (histogram.height * histogram.bins.length).toLong, @@ -110,7 +112,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { } /** Column statistics should be consistent with histograms in tests. */ - private def inferColumnStat(histogram: Histogram): ColumnStat = { + private def inferColumnStat( + histogram: Histogram, + expectedMin: Any, + expectedMax: Any): ColumnStat = { + var ndv = 0L for (i <- histogram.bins.indices) { val bin = histogram.bins(i) @@ -118,8 +124,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { ndv += bin.ndv } } - ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo), - max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4, + ColumnStat(distinctCount = Some(ndv), + min = Some(expectedMin), max = Some(expectedMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(histogram)) } @@ -343,10 +350,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(5 + 3), attributeStats = AttributeMap( // Update null count in column stats. - Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = 3), - nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3), - nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5), - nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5)))) + Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = Some(3)), + nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = Some(3)), + nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = Some(5)), + nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = Some(5))))) assert(join.stats == expectedStats) } @@ -356,11 +363,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val join = Join(table1, table2, Inner, Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2")))) // Update column stats for equi-join keys (key-1-5 and key-1-2). - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // Update column stat for other column if #outputRow / #sideRow < 1 (key-5-9), or keep it // unchanged (key-2-4). - val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = 5 * 3 / 5) + val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = Some(5 * 3 / 5)) val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 4), @@ -379,10 +386,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))) // Update column stats for join keys. 
- val joinedColStat1 = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) - val joinedColStat2 = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat1 = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) + val joinedColStat2 = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -398,8 +405,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table3, table2, LeftOuter, Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -416,8 +423,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, RightOuter, Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -466,30 +473,40 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) mutable.LinkedHashMap[Attribute, ColumnStat]( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, - min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, - min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, - min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount 
= 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 1, - min = Some(date), max = Some(date), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 1, - min = Some(timestamp), max = Some(timestamp), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(1), + min = Some(false), max = Some(false), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toByte), max = Some(1.toByte), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toShort), max = Some(1.toShort), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1), max = Some(1), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1L), max = Some(1L), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0), max = Some(1.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0f), max = Some(1.0f), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(1), min = Some(dec), max = Some(dec), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(1), + min = Some(date), max = Some(date), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(1), + min = Some(timestamp), max = Some(timestamp), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) ) } @@ -520,7 +537,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { test("join with null column") { val (nullColumn, nullColStat) = (attr("cnull"), - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 1, avgLen = 4, maxLen = 4)) + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4))) val nullTable = StatsTestPlan( outputList = Seq(nullColumn), rowCount = 1, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index cda54fa9d64f..dcb37017329f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -28,10 +28,10 @@ import org.apache.spark.sql.types._ class 
ProjectEstimationSuite extends StatsEstimationTestBase { test("project with alias") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 2, min = Some(1), - max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4)) - val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = 1, min = Some(10), - max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(2), min = Some(1), + max = Some(2), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) + val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = Some(1), min = Some(10), + max = Some(10), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1, ar2), @@ -49,8 +49,8 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { } test("project on empty table") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 0, min = None, max = None, - nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1), rowCount = 0, @@ -71,30 +71,40 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02")) val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, - min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, - min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, - min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 2, - min = Some(d1), max = Some(d2), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2, - min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(2), + min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> 
ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(6.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(7.0), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(2), min = Some(dec1), max = Some(dec2), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(2), + min = Some(d1), max = Some(d2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(2), + min = Some(t1), max = Some(t2), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) )) val columnSizes: Map[Attribute, Long] = columnInfo.map(kv => (kv._1, getColSize(kv._1, kv._2))) val child = StatsTestPlan( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 31dea2e3e7f1..9dceca59f5b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -42,8 +42,8 @@ trait StatsEstimationTestBase extends SparkFunSuite { def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes - case StringType => colStat.avgLen + 8 + 4 - case _ => colStat.avgLen + case StringType => colStat.avgLen.getOrElse(attribute.dataType.defaultSize.toLong) + 8 + 4 + case _ => colStat.avgLen.getOrElse(attribute.dataType.defaultSize) } def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)() @@ -54,6 +54,12 @@ trait StatsEstimationTestBase extends SparkFunSuite { val nameToAttr: Map[String, Attribute] = plan.output.map(a => (a.name, a)).toMap AttributeMap(colStats.map(kv => nameToAttr(kv._1) -> kv._2)) } + + /** Get a test ColumnStat with given distinctCount and nullCount */ + def rangeColumnStat(distinctCount: Int, nullCount: Int): ColumnStat = + ColumnStat(distinctCount = Some(distinctCount), + min = Some(1), max = Some(distinctCount), + nullCount = Some(0), avgLen = Some(4), 
maxLen = Some(4)) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala new file mode 100644 index 000000000000..6400898343ae --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection} +import org.apache.spark.sql.types._ + +class ArrayDataIndexedSeqSuite extends SparkFunSuite { + private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = { + assert(arrayData.numElements == array.length) + array.zipWithIndex.map { case (e, i) => + if (e != null) { + elementDt match { + // For NaN, etc. + case FloatType | DoubleType => assert(arrayData.get(i, elementDt).equals(e)) + case _ => assert(arrayData.get(i, elementDt) === e) + } + } else { + assert(arrayData.isNullAt(i)) + } + } + + val seq = arrayData.toSeq[Any](elementDt) + array.zipWithIndex.map { case (e, i) => + if (e != null) { + elementDt match { + // For Nan, etc. 
+ case FloatType | DoubleType => assert(seq(i).equals(e)) + case _ => assert(seq(i) === e) + } + } else { + assert(seq(i) == null) + } + } + + intercept[IndexOutOfBoundsException] { + seq(-1) + }.getMessage().contains("must be between 0 and the length of the ArrayData.") + + intercept[IndexOutOfBoundsException] { + seq(seq.length) + }.getMessage().contains("must be between 0 and the length of the ArrayData.") + } + + private def testArrayData(): Unit = { + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val random = new Random(100) + arrayTypes.foreach { dt => + val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil) + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + + val unsafeRowConverter = UnsafeProjection.create(schema) + val safeRowConverter = FromUnsafeProjection(schema) + + val unsafeRow = unsafeRowConverter(internalRow) + val safeRow = safeRowConverter(unsafeRow) + + val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData] + val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData] + + val elementType = dt.elementType + test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) { + compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType)) + } + + test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) { + compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType)) + } + } + } + + testArrayData() +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 625ff38943fa..cbf6106697f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -490,24 +490,36 @@ class DateTimeUtilsSuite extends SparkFunSuite { c1.set(1997, 1, 28, 10, 30, 0) val c2 = Calendar.getInstance() c2.set(1996, 9, 30, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677) - c2.set(2000, 1, 28, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(2000, 1, 29, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(1996, 2, 31, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, true, c1.getTimeZone) === 3.94959677) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, false, c1.getTimeZone) + === 3.9495967741935485) + Seq(true, false).foreach { roundOff => + c2.set(2000, 1, 28, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(2000, 1, 29, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(1996, 2, 31, 0, 0, 0) + assert(monthsBetween( + 
c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === 11) + } val c3 = Calendar.getInstance(TimeZonePST) c3.set(2000, 1, 28, 16, 0, 0) val c4 = Calendar.getInstance(TimeZonePST) c4.set(1997, 1, 28, 16, 0, 0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZonePST) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZonePST) === 36.0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZoneGMT) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZoneGMT) === 35.90322581) + assert( + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, false, TimeZoneGMT) + === 35.903225806451616) } test("from UTC timestamp") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala new file mode 100644 index 000000000000..b75739e5a3a6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.Random + +import org.apache.spark.SparkFunSuite + +class RandomUUIDGeneratorSuite extends SparkFunSuite { + test("RandomUUIDGenerator should generate version 4, variant 2 UUIDs") { + val generator = RandomUUIDGenerator(new Random().nextLong()) + for (_ <- 0 to 100) { + val uuid = generator.getNextUUID() + assert(uuid.version() == 4) + assert(uuid.variant() == 2) + } + } + + test("UUID from RandomUUIDGenerator should be deterministic") { + val r1 = new Random(100) + val generator1 = RandomUUIDGenerator(r1.nextLong()) + val r2 = new Random(100) + val generator2 = RandomUUIDGenerator(r2.nextLong()) + val r3 = new Random(101) + val generator3 = RandomUUIDGenerator(r3.nextLong()) + + for (_ <- 0 to 100) { + val uuid1 = generator1.getNextUUID() + val uuid2 = generator2.getNextUUID() + val uuid3 = generator3.getNextUUID() + assert(uuid1 == uuid2) + assert(uuid1 != uuid3) + } + } + + test("Get UTF8String UUID") { + val generator = RandomUUIDGenerator(new Random().nextLong()) + val utf8StringUUID = generator.getNextUUIDUTF8String() + val uuid = java.util.UUID.fromString(utf8StringUUID.toString) + assert(uuid.version() == 4 && uuid.variant() == 2 && utf8StringUUID.toString == uuid.toString) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 8e2b32c2b9a0..5a86f4055dce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -134,6 +134,14 @@ class DataTypeSuite extends SparkFunSuite { assert(mapped === expected) } + test("fieldNames and names returns field names") { + val struct = StructType( + StructField("a", LongType) :: StructField("b", FloatType) :: Nil) + + assert(struct.fieldNames === Seq("a", "b")) + assert(struct.names === Seq("a", "b")) + } + test("merge where right contains type conflict") { val left = StructType( StructField("a", LongType) :: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala new file mode 100644 index 000000000000..c6ca8bb00542 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.types + +import org.apache.spark.SparkFunSuite + +class StructTypeSuite extends SparkFunSuite { + + val s = StructType.fromDDL("a INT, b STRING") + + test("lookup a single missing field should output existing fields") { + val e = intercept[IllegalArgumentException](s("c")).getMessage + assert(e.contains("Available fields: a, b")) + } + + test("lookup a set of missing fields should output existing fields") { + val e = intercept[IllegalArgumentException](s(Set("a", "c"))).getMessage + assert(e.contains("Available fields: a, b")) + } + + test("lookup fieldIndex for missing field should output existing fields") { + val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage + assert(e.contains("Available fields: a, b")) + } +} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 93010c606cf4..ef41837f89d6 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml @@ -195,7 +195,7 @@ org.scalatest scalatest-maven-plugin - -ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 730a4ae8d560..74c9c0599271 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -62,10 +62,14 @@ public long durationMs() { */ public abstract void init(int index, Iterator[] iters); + /* + * Attributes of the following four methods are public. Thus, they can be also accessed from + * methods in inner classes. See SPARK-23598 + */ /** * Append a row to currentRows. */ - protected void append(InternalRow row) { + public void append(InternalRow row) { currentRows.add(row); } @@ -75,7 +79,7 @@ protected void append(InternalRow row) { * If it returns true, the caller should exit the loop that [[InputAdapter]] generates. * This interface is mainly used to limit the number of input rows. */ - protected boolean stopEarly() { + public boolean stopEarly() { return false; } @@ -84,14 +88,14 @@ protected boolean stopEarly() { * * If it returns true, the caller should exit the loop (return from processNext()). */ - protected boolean shouldStop() { + public boolean shouldStop() { return !currentRows.isEmpty(); } /** * Increase the peak execution memory for current task. 
*/ - protected void incPeakExecutionMemory(long size) { + public void incPeakExecutionMemory(long size) { TaskContext.get().taskMetrics().incPeakExecutionMemory(size); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index eb2fe82007af..9eb03430a7db 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -34,6 +34,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.unsafe.sort.*; @@ -98,19 +99,33 @@ public UnsafeKVExternalSorter( numElementsForSpillThreshold, canUseRadixSort); } else { - // The array will be used to do in-place sort, which require half of the space to be empty. - // Note: each record in the map takes two entries in the array, one is record pointer, - // another is the key prefix. - assert(map.numKeys() * 2 <= map.getArray().size() / 2); - // During spilling, the array in map will not be used, so we can borrow that and use it - // as the underlying array for in-memory sorter (it's always large enough). - // Since we will not grow the array, it's fine to pass `null` as consumer. + // During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow + // that and use it as the pointer array for `UnsafeInMemorySorter`. + LongArray pointerArray = map.getArray(); + // `BytesToBytesMap`'s pointer array is only guaranteed to hold all the distinct keys, but + // `UnsafeInMemorySorter`'s pointer array need to hold all the entries. Since + // `BytesToBytesMap` can have duplicated keys, here we need a check to make sure the pointer + // array can hold all the entries in `BytesToBytesMap`. + // The pointer array will be used to do in-place sort, which requires half of the space to be + // empty. Note: each record in the map takes two entries in the pointer array, one is record + // pointer, another is key prefix. So the required size of pointer array is `numRecords * 4`. + // TODO: It's possible to change UnsafeInMemorySorter to have multiple entries with same key, + // so that we can always reuse the pointer array. + if (map.numValues() > pointerArray.size() / 4) { + // Here we ask the map to allocate memory, so that the memory manager won't ask the map + // to spill, if the memory is not enough. + pointerArray = map.allocateArray(map.numValues() * 4L); + } + + // Since the pointer array(either reuse the one in the map, or create a new one) is guaranteed + // to be large enough, it's fine to pass `null` as consumer because we won't allocate more + // memory. 
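      // Editor's illustration (hypothetical numbers, not part of this patch): with
      // map.numValues() == 1,000,000, the sorter needs a pointer array of at least
      // 1,000,000 * 4 = 4,000,000 slots -- two slots per record (record pointer + key prefix),
      // doubled again so the in-place sort keeps half the array free. If the map's existing
      // array only has 2,000,000 slots, the check above fires (1,000,000 > 2,000,000 / 4) and a
      // 4,000,000-slot array is allocated through the map so the memory manager accounts for it.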
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( null, taskMemoryManager, comparatorSupplier.get(), prefixComparator, - map.getArray(), + pointerArray, canUseRadixSort); // We cannot use the destructive iterator here because we are reusing the existing memory @@ -241,7 +256,13 @@ private static final class KVComparator extends RecordComparator { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + public int compare( + Object baseObj1, + long baseOff1, + int baseLen1, + Object baseObj2, + long baseOff2, + int baseLen2) { // Note that since ordering doesn't need the total length of the record, we just pass 0 // into the row. row1.pointTo(baseObj1, baseOff1 + 4, 0); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java new file mode 100644 index 000000000000..82a1169cbe7a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * Exception thrown when the parquet reader find column type mismatches. + */ +@InterfaceStability.Unstable +public class SchemaColumnConvertNotSupportedException extends RuntimeException { + + /** + * Name of the column which cannot be converted. + */ + private String column; + /** + * Physical column type in the actual parquet file. + */ + private String physicalType; + /** + * Logical column type in the parquet schema the parquet reader use to parse all files. + */ + private String logicalType; + + public String getColumn() { + return column; + } + + public String getPhysicalType() { + return physicalType; + } + + public String getLogicalType() { + return logicalType; + } + + public SchemaColumnConvertNotSupportedException( + String column, + String physicalType, + String logicalType) { + super(); + this.column = column; + this.physicalType = physicalType; + this.logicalType = logicalType; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java new file mode 100644 index 000000000000..12f4d658b186 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.math.BigDecimal; + +import org.apache.orc.storage.ql.exec.vector.*; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector class wrapping Hive's ColumnVector. Because Spark ColumnarBatch only accepts + * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive ColumnVector with + * Spark ColumnarVector. + */ +public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector { + private ColumnVector baseData; + private LongColumnVector longData; + private DoubleColumnVector doubleData; + private BytesColumnVector bytesData; + private DecimalColumnVector decimalData; + private TimestampColumnVector timestampData; + private final boolean isTimestamp; + + private int batchSize; + + OrcColumnVector(DataType type, ColumnVector vector) { + super(type); + + if (type instanceof TimestampType) { + isTimestamp = true; + } else { + isTimestamp = false; + } + + baseData = vector; + if (vector instanceof LongColumnVector) { + longData = (LongColumnVector) vector; + } else if (vector instanceof DoubleColumnVector) { + doubleData = (DoubleColumnVector) vector; + } else if (vector instanceof BytesColumnVector) { + bytesData = (BytesColumnVector) vector; + } else if (vector instanceof DecimalColumnVector) { + decimalData = (DecimalColumnVector) vector; + } else if (vector instanceof TimestampColumnVector) { + timestampData = (TimestampColumnVector) vector; + } else { + throw new UnsupportedOperationException(); + } + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + + @Override + public void close() { + + } + + @Override + public boolean hasNull() { + return !baseData.noNulls; + } + + @Override + public int numNulls() { + if (baseData.isRepeating) { + if (baseData.isNull[0]) { + return batchSize; + } else { + return 0; + } + } else if (baseData.noNulls) { + return 0; + } else { + int count = 0; + for (int i = 0; i < batchSize; i++) { + if (baseData.isNull[i]) count++; + } + return count; + } + } + + /* A helper method to get the row index in a column. */ + private int getRowIndex(int rowId) { + return baseData.isRepeating ? 
0 : rowId; + } + + @Override + public boolean isNullAt(int rowId) { + return baseData.isNull[getRowIndex(rowId)]; + } + + @Override + public boolean getBoolean(int rowId) { + return longData.vector[getRowIndex(rowId)] == 1; + } + + @Override + public byte getByte(int rowId) { + return (byte) longData.vector[getRowIndex(rowId)]; + } + + @Override + public short getShort(int rowId) { + return (short) longData.vector[getRowIndex(rowId)]; + } + + @Override + public int getInt(int rowId) { + return (int) longData.vector[getRowIndex(rowId)]; + } + + @Override + public long getLong(int rowId) { + int index = getRowIndex(rowId); + if (isTimestamp) { + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000; + } else { + return longData.vector[index]; + } + } + + @Override + public float getFloat(int rowId) { + return (float) doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public double getDouble(int rowId) { + return doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); + return Decimal.apply(data, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; + int index = getRowIndex(rowId); + BytesColumnVector col = bytesData; + return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; + int index = getRowIndex(rowId); + byte[] binary = new byte[bytesData.length[index]]; + System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length); + return binary; + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java new file mode 100644 index 000000000000..dcebdc39f0aa --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -0,0 +1,555 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.io.IOException; +import java.util.stream.IntStream; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.lib.input.FileSplit; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.orc.mapred.OrcInputFormat; +import org.apache.orc.storage.common.type.HiveDecimal; +import org.apache.orc.storage.ql.exec.vector.*; +import org.apache.orc.storage.serde2.io.HiveDecimalWritable; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +/** + * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch. + * After creating, `initialize` and `initBatch` should be called sequentially. + */ +public class OrcColumnarBatchReader extends RecordReader { + + // The capacity of vectorized batch. + private int capacity; + + // Vectorized ORC Row Batch + private VectorizedRowBatch batch; + + /** + * The column IDs of the physical ORC file schema which are required by this reader. + * -1 means this required column doesn't exist in the ORC file. + */ + private int[] requestedColIds; + + // Record reader from ORC row batch. + private org.apache.orc.RecordReader recordReader; + + private StructField[] requiredFields; + + // The result columnar batch for vectorized execution by whole-stage codegen. + private ColumnarBatch columnarBatch; + + // Writable column vectors of the result columnar batch. + private WritableColumnVector[] columnVectors; + + // The wrapped ORC column vectors. It should be null if `copyToSpark` is true. + private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers; + + // The memory mode of the columnarBatch + private final MemoryMode MEMORY_MODE; + + // Whether or not to copy the ORC columnar batch to Spark columnar batch. + private final boolean copyToSpark; + + public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark, int capacity) { + MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; + this.copyToSpark = copyToSpark; + this.capacity = capacity; + } + + + @Override + public Void getCurrentKey() { + return null; + } + + @Override + public ColumnarBatch getCurrentValue() { + return columnarBatch; + } + + @Override + public float getProgress() throws IOException { + return recordReader.getProgress(); + } + + @Override + public boolean nextKeyValue() throws IOException { + return nextBatch(); + } + + @Override + public void close() throws IOException { + if (columnarBatch != null) { + columnarBatch.close(); + columnarBatch = null; + } + if (recordReader != null) { + recordReader.close(); + recordReader = null; + } + } + + /** + * Initialize ORC file reader and batch record reader. + * Please note that `initBatch` is needed to be called after this. 
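   *
   * Rough usage sketch (editor's illustration; variable names are hypothetical, the calls are
   * the ones defined in this class):
   *   OrcColumnarBatchReader reader = new OrcColumnarBatchReader(useOffHeap, copyToSpark, 4096);
   *   reader.initialize(fileSplit, taskAttemptContext);
   *   reader.initBatch(orcSchema, requestedColIds, requiredFields, partitionSchema, partitionValues);
   *   while (reader.nextKeyValue()) { ColumnarBatch batch = reader.getCurrentValue(); /* ... */ }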
+ */ + @Override + public void initialize( + InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException { + FileSplit fileSplit = (FileSplit)inputSplit; + Configuration conf = taskAttemptContext.getConfiguration(); + Reader reader = OrcFile.createReader( + fileSplit.getPath(), + OrcFile.readerOptions(conf) + .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) + .filesystem(fileSplit.getPath().getFileSystem(conf))); + Reader.Options options = + OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); + recordReader = reader.rows(options); + } + + /** + * Initialize columnar batch by setting required schema and partition information. + * With this information, this creates ColumnarBatch with the full schema. + */ + public void initBatch( + TypeDescription orcSchema, + int[] requestedColIds, + StructField[] requiredFields, + StructType partitionSchema, + InternalRow partitionValues) { + batch = orcSchema.createRowBatch(capacity); + assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. + + this.requiredFields = requiredFields; + this.requestedColIds = requestedColIds; + assert(requiredFields.length == requestedColIds.length); + + StructType resultSchema = new StructType(requiredFields); + for (StructField f : partitionSchema.fields()) { + resultSchema = resultSchema.add(f); + } + + if (copyToSpark) { + if (MEMORY_MODE == MemoryMode.OFF_HEAP) { + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); + } + + // Initialize the missing columns once. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] == -1) { + columnVectors[i].putNulls(0, capacity); + columnVectors[i].setIsConstant(); + } + } + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); + columnVectors[i + partitionIdx].setIsConstant(); + } + } + + columnarBatch = new ColumnarBatch(columnVectors); + } else { + // Just wrap the ORC column vector instead of copying it to Spark column vector. + orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; + + for (int i = 0; i < requiredFields.length; i++) { + DataType dt = requiredFields[i].dataType(); + int colId = requestedColIds[i]; + // Initialize the missing columns once. + if (colId == -1) { + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + orcVectorWrappers[i] = missingCol; + } else { + orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]); + } + } + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + DataType dt = partitionSchema.fields()[i].dataType(); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + ColumnVectorUtils.populate(partitionCol, partitionValues, i); + partitionCol.setIsConstant(); + orcVectorWrappers[partitionIdx + i] = partitionCol; + } + } + + columnarBatch = new ColumnarBatch(orcVectorWrappers); + } + } + + /** + * Return true if there exists more data in the next batch. If exists, prepare the next batch + * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. 
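   *
   * Editor's note: when `copyToSpark` is false no copy is performed; the ORC vectors are already
   * wrapped by OrcColumnVector, so this method only refreshes each wrapper's batch size.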
+ */ + private boolean nextBatch() throws IOException { + recordReader.nextBatch(batch); + int batchSize = batch.size; + if (batchSize == 0) { + return false; + } + columnarBatch.setNumRows(batchSize); + + if (!copyToSpark) { + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] != -1) { + ((OrcColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize); + } + } + return true; + } + + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + + for (int i = 0; i < requiredFields.length; i++) { + StructField field = requiredFields[i]; + WritableColumnVector toColumn = columnVectors[i]; + + if (requestedColIds[i] >= 0) { + ColumnVector fromColumn = batch.cols[requestedColIds[i]]; + + if (fromColumn.isRepeating) { + putRepeatingValues(batchSize, field, fromColumn, toColumn); + } else if (fromColumn.noNulls) { + putNonNullValues(batchSize, field, fromColumn, toColumn); + } else { + putValues(batchSize, field, fromColumn, toColumn); + } + } + } + return true; + } + + private void putRepeatingValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + if (fromColumn.isNull[0]) { + toColumn.putNulls(0, batchSize); + } else { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + toColumn.putBooleans(0, batchSize, ((LongColumnVector)fromColumn).vector[0] == 1); + } else if (type instanceof ByteType) { + toColumn.putBytes(0, batchSize, (byte)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof ShortType) { + toColumn.putShorts(0, batchSize, (short)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof IntegerType || type instanceof DateType) { + toColumn.putInts(0, batchSize, (int)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof LongType) { + toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof TimestampType) { + toColumn.putLongs(0, batchSize, + fromTimestampColumnVector((TimestampColumnVector)fromColumn, 0)); + } else if (type instanceof FloatType) { + toColumn.putFloats(0, batchSize, (float)((DoubleColumnVector)fromColumn).vector[0]); + } else if (type instanceof DoubleType) { + toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]); + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = (BytesColumnVector)fromColumn; + int size = data.vector[0].length; + toColumn.arrayData().reserve(size); + toColumn.arrayData().putBytes(0, size, data.vector[0], 0); + for (int index = 0; index < batchSize; index++) { + toColumn.putArray(index, 0, size); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + putDecimalWritables( + toColumn, + batchSize, + decimalType.precision(), + decimalType.scale(), + ((DecimalColumnVector)fromColumn).vector[0]); + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + } + + private void putNonNullValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putBoolean(index, data[index] == 1); + } + } else if (type instanceof ByteType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putByte(index, 
(byte)data[index]); + } + } else if (type instanceof ShortType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putShort(index, (short)data[index]); + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putInt(index, (int)data[index]); + } + } else if (type instanceof LongType) { + toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector, 0); + } else if (type instanceof TimestampType) { + TimestampColumnVector data = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + toColumn.putLong(index, fromTimestampColumnVector(data, index)); + } + } else if (type instanceof FloatType) { + double[] data = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putFloat(index, (float)data[index]); + } + } else if (type instanceof DoubleType) { + toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0); + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = ((BytesColumnVector)fromColumn); + WritableColumnVector arrayData = toColumn.arrayData(); + int totalNumBytes = IntStream.of(data.length).sum(); + arrayData.reserve(totalNumBytes); + for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) { + arrayData.putBytes(pos, data.length[index], data.vector[index], data.start[index]); + toColumn.putArray(index, pos, data.length[index]); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + DecimalColumnVector data = ((DecimalColumnVector)fromColumn); + if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { + toColumn.arrayData().reserve(batchSize * 16); + } + for (int index = 0; index < batchSize; index++) { + putDecimalWritable( + toColumn, + index, + decimalType.precision(), + decimalType.scale(), + data.vector[index]); + } + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + + private void putValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putBoolean(index, vector[index] == 1); + } + } + } else if (type instanceof ByteType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putByte(index, (byte)vector[index]); + } + } + } else if (type instanceof ShortType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putShort(index, (short)vector[index]); + } + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putInt(index, (int)vector[index]); + } + } + } else if (type instanceof LongType) { + long[] vector = 
((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, vector[index]); + } + } + } else if (type instanceof TimestampType) { + TimestampColumnVector vector = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, fromTimestampColumnVector(vector, index)); + } + } + } else if (type instanceof FloatType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putFloat(index, (float)vector[index]); + } + } + } else if (type instanceof DoubleType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putDouble(index, vector[index]); + } + } + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector vector = (BytesColumnVector)fromColumn; + WritableColumnVector arrayData = toColumn.arrayData(); + int totalNumBytes = IntStream.of(vector.length).sum(); + arrayData.reserve(totalNumBytes); + for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + arrayData.putBytes(pos, vector.length[index], vector.vector[index], vector.start[index]); + toColumn.putArray(index, pos, vector.length[index]); + } + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector; + if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { + toColumn.arrayData().reserve(batchSize * 16); + } + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + putDecimalWritable( + toColumn, + index, + decimalType.precision(), + decimalType.scale(), + vector[index]); + } + } + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + + /** + * Returns the number of micros since epoch from an element of TimestampColumnVector. + */ + private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) { + return vector.time[index] * 1000L + vector.nanos[index] / 1000L; + } + + /** + * Put a `HiveDecimalWritable` to a `WritableColumnVector`. + */ + private static void putDecimalWritable( + WritableColumnVector toColumn, + int index, + int precision, + int scale, + HiveDecimalWritable decimalWritable) { + HiveDecimal decimal = decimalWritable.getHiveDecimal(); + Decimal value = + Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale()); + value.changePrecision(precision, scale); + + if (precision <= Decimal.MAX_INT_DIGITS()) { + toColumn.putInt(index, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + toColumn.putLong(index, value.toUnscaledLong()); + } else { + byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); + toColumn.arrayData().putBytes(index * 16, bytes.length, bytes, 0); + toColumn.putArray(index, index * 16, bytes.length); + } + } + + /** + * Put `HiveDecimalWritable`s to a `WritableColumnVector`. 
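   *
   * Editor's note on the storage format used below: values whose precision fits in an int
   * (precision <= Decimal.MAX_INT_DIGITS(), i.e. 9) are stored as ints, values that fit in a
   * long (<= Decimal.MAX_LONG_DIGITS(), i.e. 18) as longs, and anything larger as the unscaled
   * BigInteger bytes written into the vector's arrayData(). For example, a DecimalType(20, 2)
   * column takes the byte-array path.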
+ */ + private static void putDecimalWritables( + WritableColumnVector toColumn, + int size, + int precision, + int scale, + HiveDecimalWritable decimalWritable) { + HiveDecimal decimal = decimalWritable.getHiveDecimal(); + Decimal value = + Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale()); + value.changePrecision(precision, scale); + + if (precision <= Decimal.MAX_INT_DIGITS()) { + toColumn.putInts(0, size, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + toColumn.putLongs(0, size, value.toUnscaledLong()); + } else { + byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); + toColumn.arrayData().reserve(bytes.length); + toColumn.arrayData().putBytes(0, bytes.length, bytes, 0); + for (int index = 0; index < size; index++) { + toColumn.putArray(index, 0, bytes.length); + } + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 80c2f491b48c..e65cd252c3dd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -170,7 +170,7 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont * Returns the list of files at 'path' recursively. This skips files that are ignored normally * by MapReduce. */ - public static List listDirectory(File path) throws IOException { + public static List listDirectory(File path) { List result = new ArrayList<>(); if (path.isDirectory()) { for (File f: path.listFiles()) { @@ -231,7 +231,7 @@ protected void initialize(String path, List columns) throws IOException } @Override - public Void getCurrentKey() throws IOException, InterruptedException { + public Void getCurrentKey() { return null; } @@ -259,7 +259,7 @@ public ValuesReaderIntIterator(ValuesReader delegate) { } @Override - int nextInt() throws IOException { + int nextInt() { return delegate.readInteger(); } } @@ -279,15 +279,15 @@ int nextInt() throws IOException { protected static final class NullIntIterator extends IntIterator { @Override - int nextInt() throws IOException { return 0; } + int nextInt() { return 0; } } /** * Creates a reader for definition and repetition levels, returning an optimized one if * the levels are not needed. 
*/ - protected static IntIterator createRLEIterator(int maxLevel, BytesInput bytes, - ColumnDescriptor descriptor) throws IOException { + protected static IntIterator createRLEIterator( + int maxLevel, BytesInput bytes, ColumnDescriptor descriptor) throws IOException { try { if (maxLevel == 0) return new NullIntIterator(); return new RLEIntIterator( diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index c120863152a9..72f1d024b08c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.util.Arrays; import java.util.TimeZone; import org.apache.parquet.bytes.BytesUtils; @@ -31,6 +32,7 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -231,6 +233,18 @@ private boolean shouldConvertTimestamps() { return convertTz != null && !convertTz.equals(UTC); } + /** + * Helper function to construct exception for parquet schema mismatch. + */ + private SchemaColumnConvertNotSupportedException constructConvertNotSupportedException( + ColumnDescriptor descriptor, + WritableColumnVector column) { + return new SchemaColumnConvertNotSupportedException( + Arrays.toString(descriptor.getPath()), + descriptor.getType().toString(), + column.dataType().toString()); + } + /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. 
*/ @@ -261,7 +275,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -282,7 +296,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -321,7 +335,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException(); + throw constructConvertNotSupportedException(descriptor, column); } break; case BINARY: @@ -360,7 +374,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException(); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -375,7 +389,9 @@ private void decodeDictionaryIds( */ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { - assert(column.dataType() == DataTypes.BooleanType); + if (column.dataType() != DataTypes.BooleanType) { + throw constructConvertNotSupportedException(descriptor, column); + } defColumn.readBooleans( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } @@ -394,7 +410,7 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -414,7 +430,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) { } } } else { - throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -425,7 +441,7 @@ private void readFloatBatch(int rowId, int num, WritableColumnVector column) { defColumn.readFloats( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -436,7 +452,7 @@ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { defColumn.readDoubles( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -444,7 +460,8 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { // This is where we implement support for the valid type conversions. 
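     // Editor's note: the valid targets here are StringType, BinaryType, byte-array-backed
     // decimals (added by this change via DecimalType.isByteArrayDecimalType), and TimestampType
     // (INT96); any other type now raises SchemaColumnConvertNotSupportedException instead of a
     // bare UnsupportedOperationException.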
// TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; - if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType) { + if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType + || DecimalType.isByteArrayDecimalType(column.dataType())) { defColumn.readBinarys(num, column, rowId, maxDefLevel, data); } else if (column.dataType() == DataTypes.TimestampType) { if (!shouldConvertTimestamps()) { @@ -470,7 +487,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -509,7 +526,7 @@ private void readFixedLenByteArrayBatch( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 6c157e85d411..5934a23db8af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,10 +31,10 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; -import org.apache.spark.sql.execution.vectorized.ColumnarBatch; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -50,6 +50,10 @@ * TODO: make this always return ColumnarBatches. */ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase { + + // The capacity of vectorized batch. + private int capacity; + /** * Batch of rows that we assemble and the current index we've returned. Every time this * batch is used up (batchIdx == numBatched), we populated the batch. @@ -112,13 +116,10 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private final MemoryMode MEMORY_MODE; - public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap) { + public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap, int capacity) { this.convertTz = convertTz; MEMORY_MODE = useOffHeap ? 
MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; - } - - public VectorizedParquetRecordReader(boolean useOffHeap) { - this(null, useOffHeap); + this.capacity = capacity; } /** @@ -152,7 +153,7 @@ public void close() throws IOException { } @Override - public boolean nextKeyValue() throws IOException, InterruptedException { + public boolean nextKeyValue() throws IOException { resultBatch(); if (returnColumnarBatch) return nextBatch(); @@ -165,13 +166,13 @@ public boolean nextKeyValue() throws IOException, InterruptedException { } @Override - public Object getCurrentValue() throws IOException, InterruptedException { + public Object getCurrentValue() { if (returnColumnarBatch) return columnarBatch; return columnarBatch.getRow(batchIdx - 1); } @Override - public float getProgress() throws IOException, InterruptedException { + public float getProgress() { return (float) rowsReturned / totalRowCount; } @@ -181,7 +182,7 @@ public float getProgress() throws IOException, InterruptedException { // Columns 0,1: data columns // Column 2: partitionValues[0] // Column 3: partitionValues[1] - public void initBatch( + private void initBatch( MemoryMode memMode, StructType partitionColumns, InternalRow partitionValues) { @@ -195,13 +196,12 @@ public void initBatch( } } - int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; if (memMode == MemoryMode.OFF_HEAP) { columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); } else { columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); } - columnarBatch = new ColumnarBatch(batchSchema, columnVectors, capacity); + columnarBatch = new ColumnarBatch(columnVectors); if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; for (int i = 0; i < partitionColumns.fields().length; i++) { @@ -213,13 +213,13 @@ public void initBatch( // Initialize missing columns with nulls. for (int i = 0; i < missingColumns.length; i++) { if (missingColumns[i]) { - columnVectors[i].putNulls(0, columnarBatch.capacity()); + columnVectors[i].putNulls(0, capacity); columnVectors[i].setIsConstant(); } } } - public void initBatch() { + private void initBatch() { initBatch(MEMORY_MODE, null, null); } @@ -248,11 +248,14 @@ public void enableReturningBatches() { * Advances to the next batch of rows. Returns false if there are no more. */ public boolean nextBatch() throws IOException { - columnarBatch.reset(); + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); + int num = (int) Math.min((long) capacity, totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; columnReaders[i].readBatch(num, columnVectors[i]); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java deleted file mode 100644 index dc7c1269bedd..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.vectorized; - -import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * This class represents in-memory values of a column and provides the main APIs to access the data. - * It supports all the types and contains get APIs as well as their batched versions. The batched - * versions are considered to be faster and preferable whenever possible. - * - * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these - * columns have child columns. All of the data are stored in the child columns and the parent column - * only contains nullability. In the case of Arrays, the lengths and offsets are saved in the child - * column and are encoded identically to INTs. - * - * Maps are just a special case of a two field struct. - * - * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values - * in the current batch. - */ -public abstract class ColumnVector implements AutoCloseable { - - /** - * Returns the data type of this column. - */ - public final DataType dataType() { return type; } - - /** - * Cleans up memory for this column. The column is not usable after this. - */ - public abstract void close(); - - /** - * Returns the number of nulls in this column. - */ - public abstract int numNulls(); - - /** - * Returns whether the value at rowId is NULL. - */ - public abstract boolean isNullAt(int rowId); - - /** - * Returns the value for rowId. - */ - public abstract boolean getBoolean(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract boolean[] getBooleans(int rowId, int count); - - /** - * Returns the value for rowId. - */ - public abstract byte getByte(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract byte[] getBytes(int rowId, int count); - - /** - * Returns the value for rowId. - */ - public abstract short getShort(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract short[] getShorts(int rowId, int count); - - /** - * Returns the value for rowId. - */ - public abstract int getInt(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract int[] getInts(int rowId, int count); - - /** - * Returns the value for rowId. - */ - public abstract long getLong(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract long[] getLongs(int rowId, int count); - - /** - * Returns the value for rowId. - */ - public abstract float getFloat(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract float[] getFloats(int rowId, int count); - - /** - * Returns the value for rowId. 
- */ - public abstract double getDouble(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract double[] getDoubles(int rowId, int count); - - /** - * Returns the length of the array for rowId. - */ - public abstract int getArrayLength(int rowId); - - /** - * Returns the offset of the array for rowId. - */ - public abstract int getArrayOffset(int rowId); - - /** - * Returns the struct for rowId. - */ - public final ColumnarRow getStruct(int rowId) { - return new ColumnarRow(this, rowId); - } - - /** - * A special version of {@link #getStruct(int)}, which is only used as an adapter for Spark - * codegen framework, the second parameter is totally ignored. - */ - public final ColumnarRow getStruct(int rowId, int size) { - return getStruct(rowId); - } - - /** - * Returns the array for rowId. - */ - public final ColumnarArray getArray(int rowId) { - return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); - } - - /** - * Returns the map for rowId. - */ - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); - } - - /** - * Returns the decimal for rowId. - */ - public abstract Decimal getDecimal(int rowId, int precision, int scale); - - /** - * Returns the UTF8String for rowId. Note that the returned UTF8String may point to the data of - * this column vector, please copy it if you want to keep it after this column vector is freed. - */ - public abstract UTF8String getUTF8String(int rowId); - - /** - * Returns the byte array for rowId. - */ - public abstract byte[] getBinary(int rowId); - - /** - * Returns the data for the underlying array. - */ - public abstract ColumnVector arrayData(); - - /** - * Returns the ordinal's child data column. - */ - public abstract ColumnVector getChildColumn(int ordinal); - - /** - * Data type for this column. - */ - protected DataType type; - - /** - * Sets up the common state and also handles creating the child columns if this is a nested - * type. 
- */ - protected ColumnVector(DataType type) { - this.type = type; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index bc62bc43484e..829f3ce750fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -20,14 +20,19 @@ import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.sql.Date; +import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Map; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -83,8 +88,8 @@ public static void populate(WritableColumnVector col, InternalRow row, int field } } else if (t instanceof CalendarIntervalType) { CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t); - col.getChildColumn(0).putInts(0, capacity, c.months); - col.getChildColumn(1).putLongs(0, capacity, c.microseconds); + col.getChild(0).putInts(0, capacity, c.months); + col.getChild(1).putLongs(0, capacity, c.microseconds); } else if (t instanceof DateType) { col.putInts(0, capacity, row.getInt(fieldIdx)); } else if (t instanceof TimestampType) { @@ -107,6 +112,18 @@ public static int[] toJavaIntArray(ColumnarArray array) { return array.toIntArray(); } + public static Map toJavaIntMap(ColumnarMap map) { + int[] keys = toJavaIntArray(map.keyArray()); + int[] values = toJavaIntArray(map.valueArray()); + assert keys.length == values.length; + + Map result = new HashMap<>(); + for (int i = 0; i < keys.length; i++) { + result.put(keys[i], values[i]); + } + return result; + } + private static void appendValue(WritableColumnVector dst, DataType t, Object o) { if (o == null) { if (t instanceof CalendarIntervalType) { @@ -116,19 +133,19 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) } } else { if (t == DataTypes.BooleanType) { - dst.appendBoolean(((Boolean)o).booleanValue()); + dst.appendBoolean((Boolean) o); } else if (t == DataTypes.ByteType) { - dst.appendByte(((Byte) o).byteValue()); + dst.appendByte((Byte) o); } else if (t == DataTypes.ShortType) { - dst.appendShort(((Short)o).shortValue()); + dst.appendShort((Short) o); } else if (t == DataTypes.IntegerType) { - dst.appendInt(((Integer)o).intValue()); + dst.appendInt((Integer) o); } else if (t == DataTypes.LongType) { - dst.appendLong(((Long)o).longValue()); + dst.appendLong((Long) o); } else if (t == DataTypes.FloatType) { - dst.appendFloat(((Float)o).floatValue()); + dst.appendFloat((Float) o); } else if (t == DataTypes.DoubleType) { - dst.appendDouble(((Double)o).doubleValue()); + dst.appendDouble((Double) o); } else if (t == DataTypes.StringType) { byte[] b =((String)o).getBytes(StandardCharsets.UTF_8); dst.appendByteArray(b, 0, b.length); @@ -147,8 +164,8 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) } else if (t instanceof CalendarIntervalType) { CalendarInterval c = (CalendarInterval)o; 
dst.appendStruct(false); - dst.getChildColumn(0).appendInt(c.months); - dst.getChildColumn(1).appendLong(c.microseconds); + dst.getChild(0).appendInt(c.months); + dst.getChild(1).appendLong(c.microseconds); } else if (t instanceof DateType) { dst.appendInt(DateTimeUtils.fromJavaDate((Date)o)); } else { @@ -177,7 +194,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i dst.appendStruct(false); Row c = src.getStruct(fieldIdx); for (int i = 0; i < st.fields().length; i++) { - appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i); + appendValue(dst.getChild(i), st.fields()[i].dataType(), c, i); } } } else { @@ -190,7 +207,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i */ public static ColumnarBatch toBatch( StructType schema, MemoryMode memMode, Iterator row) { - int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; + int capacity = 4 * 1024; WritableColumnVector[] columnVectors; if (memMode == MemoryMode.OFF_HEAP) { columnVectors = OffHeapColumnVector.allocateColumns(capacity, schema); @@ -206,7 +223,7 @@ public static ColumnarBatch toBatch( } n++; } - ColumnarBatch batch = new ColumnarBatch(schema, columnVectors, capacity); + ColumnarBatch batch = new ColumnarBatch(columnVectors); batch.setNumRows(n); return batch; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 06602c147dfe..4e4242fe8d9b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -21,8 +21,12 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -123,45 +127,37 @@ public boolean anyNull() { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChildColumn(0).getInt(rowId); - final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); - return new CalendarInterval(months, microseconds); + return columns[ordinal].getInterval(rowId); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return 
columns[ordinal].getArray(rowId); } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + return columns[ordinal].getMap(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 1c45b846790b..4733f36174f4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -60,7 +61,7 @@ public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] private long nulls; private long data; - // Set iff the type is array. + // Only set if type is Array or Map. private long lengthData; private long offsetData; @@ -123,7 +124,7 @@ public void putNulls(int rowId, int count) { @Override public void putNotNulls(int rowId, int count) { - if (numNulls == 0) return; + if (!hasNull()) return; long offset = nulls + rowId; for (int i = 0; i < count; ++i, ++offset) { Platform.putByte(null, offset, (byte) 0); @@ -206,7 +207,7 @@ public byte[] getBytes(int rowId, int count) { @Override protected UTF8String getBytesAsUTF8String(int rowId, int count) { - return UTF8String.fromAddress(null, data + rowId, count); + return new UTF8String(new OffHeapMemoryBlock(data + rowId, count)); } // @@ -530,7 +531,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { @Override protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; - if (isArray()) { + if (isArray() || type instanceof MapType) { this.lengthData = Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 1d538fe4181b..23dcc104e67c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -69,7 +69,7 @@ public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] f private float[] floatData; private double[] doubleData; - // Only set if type is Array. + // Only set if type is Array or Map. private int[] arrayLengths; private int[] arrayOffsets; @@ -119,7 +119,7 @@ public void putNulls(int rowId, int count) { @Override public void putNotNulls(int rowId, int count) { - if (numNulls == 0) return; + if (!hasNull()) return; for (int i = 0; i < count; ++i) { nulls[rowId + i] = (byte)0; } @@ -503,7 +503,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { // Spilt this function out since it is the slow path. 
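The reserveInternal changes above extend the array offset/length bookkeeping to MapType. The following is a minimal, hypothetical sketch (not part of this patch) of how the new layout is written and read back: keys live in child vector 0, values in child vector 1, and the parent vector keeps per-row offsets and lengths, mirroring the existing array handling. It only uses accessors that appear in this patch, but the wiring is illustrative.

import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.vectorized.ColumnarMap;

public class MapVectorSketch {
  public static void main(String[] args) {
    // A map<int, int> column with room for 16 rows.
    OnHeapColumnVector mapCol = new OnHeapColumnVector(
        16, DataTypes.createMapType(DataTypes.IntegerType, DataTypes.IntegerType));

    // Keys go into child 0, values into child 1.
    mapCol.getChild(0).appendInt(1);
    mapCol.getChild(1).appendInt(10);
    mapCol.getChild(0).appendInt(2);
    mapCol.getChild(1).appendInt(20);

    // Row 0 spans child elements [0, 2), tracked by the parent's offset/length arrays.
    mapCol.putArray(0, 0, 2);

    ColumnarMap map = mapCol.getMap(0);
    System.out.println(map.numElements());          // 2
    System.out.println(map.keyArray().getInt(1));   // 2
    System.out.println(map.valueArray().getInt(1)); // 20
  }
}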
@Override protected void reserveInternal(int newCapacity) { - if (isArray()) { + if (isArray() || type instanceof MapType) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5f6f125976e1..b0e119d658cb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,9 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -57,8 +60,8 @@ public void reset() { elementsAppended = 0; if (numNulls > 0) { putNotNulls(0, capacity); + numNulls = 0; } - numNulls = 0; } @Override @@ -78,7 +81,9 @@ public void close() { } public void reserve(int requiredCapacity) { - if (requiredCapacity > capacity) { + if (requiredCapacity < 0) { + throwUnsupportedException(requiredCapacity, null); + } else if (requiredCapacity > capacity) { int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); if (requiredCapacity <= newCapacity) { try { @@ -93,13 +98,24 @@ public void reserve(int requiredCapacity) { } private void throwUnsupportedException(int requiredCapacity, Throwable cause) { - String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + - "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + - " to false."; + String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" + + (requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") + + "). As a workaround, you can reduce the vectorized reader batch size, or disable the " + + "vectorized reader. 
For parquet file format, refer to " + + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " + + "refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + "."; throw new RuntimeException(message, cause); } + @Override + public boolean hasNull() { + return numNulls > 0; + } + @Override public int numNulls() { return numNulls; } @@ -330,6 +346,7 @@ public final int putByteArray(int rowId, byte[] value) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; if (precision <= Decimal.MAX_INT_DIGITS()) { return Decimal.createUnsafe(getInt(rowId), precision, scale); } else if (precision <= Decimal.MAX_LONG_DIGITS()) { @@ -356,6 +373,7 @@ public void putDecimal(int rowId, Decimal value, int precision) { @Override public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; if (dictionary == null) { return arrayData().getBytesAsUTF8String(getArrayOffset(rowId), getArrayLength(rowId)); } else { @@ -373,6 +391,7 @@ public UTF8String getUTF8String(int rowId) { @Override public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; if (dictionary == null) { return arrayData().getBytes(getArrayOffset(rowId), getArrayLength(rowId)); } else { @@ -585,11 +604,11 @@ public final int appendArray(int length) { public final int appendStruct(boolean isNull) { if (isNull) { appendNull(); - for (ColumnVector c: childColumns) { + for (WritableColumnVector c: childColumns) { if (c.type instanceof StructType) { - ((WritableColumnVector) c).appendStruct(true); + c.appendStruct(true); } else { - ((WritableColumnVector) c).appendNull(); + c.appendNull(); } } } else { @@ -598,17 +617,32 @@ public final int appendStruct(boolean isNull) { return elementsAppended; } - /** - * Returns the data for the underlying array. - */ + // `WritableColumnVector` puts the data of array in the first child column vector, and puts the + // array offsets and lengths in the current column vector. @Override - public WritableColumnVector arrayData() { return childColumns[0]; } + public final ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) return null; + return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); + } - /** - * Returns the ordinal's child data column. - */ + // `WritableColumnVector` puts the key array in the first child column vector, value array in the + // second child column vector, and puts the offsets and lengths in the current column vector. @Override - public WritableColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public final ColumnarMap getMap(int rowId) { + if (isNullAt(rowId)) return null; + return new ColumnarMap(getChild(0), getChild(1), getArrayOffset(rowId), getArrayLength(rowId)); + } + + public WritableColumnVector arrayData() { + return childColumns[0]; + } + + public abstract int getArrayLength(int rowId); + + public abstract int getArrayOffset(int rowId); + + @Override + public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; } /** * Returns the elements appended. 
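The reworked error message above points users at per-format batch-size settings instead of only suggesting to disable vectorization. A hedged sketch of turning those knobs down follows; the literal config keys are assumptions inferred from the SQLConf entries named in the message, not quoted from this patch.

import org.apache.spark.sql.SparkSession;

public class BatchSizeSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

    // Assumed keys behind PARQUET_VECTORIZED_READER_BATCH_SIZE / ORC_VECTORIZED_READER_BATCH_SIZE.
    spark.conf().set("spark.sql.parquet.columnarReaderBatchSize", "1024");
    spark.conf().set("spark.sql.orc.columnarReaderBatchSize", "1024");

    // Last resort: disable the vectorized readers entirely.
    spark.conf().set("spark.sql.parquet.enableVectorizedReader", "false");
    spark.conf().set("spark.sql.orc.enableVectorizedReader", "false");
  }
}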
@@ -692,6 +726,11 @@ protected WritableColumnVector(int capacity, DataType type) { for (int i = 0; i < childColumns.length; ++i) { this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } + } else if (type instanceof MapType) { + MapType mapType = (MapType) type; + this.childColumns = new WritableColumnVector[2]; + this.childColumns[0] = reserveNewColumn(capacity, mapType.keyType()); + this.childColumns[1] = reserveNewColumn(capacity, mapType.valueType()); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. this.childColumns = new WritableColumnVector[2]; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java similarity index 87% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java index 3136cee1f655..7df5a451ae5f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java @@ -15,19 +15,19 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2; import java.util.Optional; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * provide data reading ability for continuous stream processing. */ +@InterfaceStability.Evolving public interface ContinuousReadSupport extends DataSourceV2 { /** * Creates a {@link ContinuousReader} to scan the data from this data source. @@ -42,5 +42,5 @@ public interface ContinuousReadSupport extends DataSourceV2 { ContinuousReader createContinuousReader( Optional schema, String checkpointLocation, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java new file mode 100644 index 000000000000..83df3be74708 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Stream; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An immutable string-to-string map in which keys are case-insensitive. This is used to represent + * data source options. + * + * Each data source implementation can define its own options and teach its users how to set them. + * Spark doesn't have any restrictions about what options a data source should or should not have. + * Instead Spark defines some standard options that data sources can optionally adopt. It's possible + * that some options are very common and many data sources use them. However different data + * sources may define the common options(key and meaning) differently, which is quite confusing to + * end users. + * + * The standard options defined by Spark: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
+ *   Option key: "path". Option value: A path string of the data files/directories, like
+ *     path1, /absolute/file2, path3/*. The path can either be relative or absolute, points to
+ *     either file or directory, and can contain wildcards. This option is commonly used by
+ *     file-based data sources.
+ *   Option key: "paths". Option value: A JSON array style paths string of the data
+ *     files/directories, like ["path1", "/absolute/file2"]. The format of each path is same as
+ *     the path option, plus it should follow JSON string literal format, e.g. quotes should be
+ *     escaped, pa\"th means pa"th.
+ *   Option key: "table". Option value: A table name string representing the table name directly
+ *     without any interpretation. For example, db.tbl means a table called db.tbl, not a table
+ *     called tbl inside database db. `t*b.l` means a table called `t*b.l`, not t*b.l.
+ *   Option key: "database". Option value: A database name string representing the database name
+ *     directly without any interpretation, which is very similar to the table name option.
      + */ +@InterfaceStability.Evolving +public class DataSourceOptions { + private final Map keyLowerCasedMap; + + private String toLowerCase(String key) { + return key.toLowerCase(Locale.ROOT); + } + + public static DataSourceOptions empty() { + return new DataSourceOptions(new HashMap<>()); + } + + public DataSourceOptions(Map originalMap) { + keyLowerCasedMap = new HashMap<>(originalMap.size()); + for (Map.Entry entry : originalMap.entrySet()) { + keyLowerCasedMap.put(toLowerCase(entry.getKey()), entry.getValue()); + } + } + + public Map asMap() { + return new HashMap<>(keyLowerCasedMap); + } + + /** + * Returns the option value to which the specified key is mapped, case-insensitively. + */ + public Optional get(String key) { + return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key))); + } + + /** + * Returns the boolean value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public boolean getBoolean(String key, boolean defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Boolean.parseBoolean(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + + /** + * Returns the integer value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public int getInt(String key, int defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Integer.parseInt(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + + /** + * Returns the long value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public long getLong(String key, long defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + + /** + * Returns the double value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public double getDouble(String key, double defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + + /** + * The option key for singular path. + */ + public static final String PATH_KEY = "path"; + + /** + * The option key for multiple paths. + */ + public static final String PATHS_KEY = "paths"; + + /** + * The option key for table name. + */ + public static final String TABLE_KEY = "table"; + + /** + * The option key for database name. + */ + public static final String DATABASE_KEY = "database"; + + /** + * Returns all the paths specified by both the singular path option and the multiple + * paths option. + */ + public String[] paths() { + String[] singularPath = + get(PATH_KEY).map(s -> new String[]{s}).orElseGet(() -> new String[0]); + Optional pathsStr = get(PATHS_KEY); + if (pathsStr.isPresent()) { + ObjectMapper objectMapper = new ObjectMapper(); + try { + String[] paths = objectMapper.readValue(pathsStr.get(), String[].class); + return Stream.of(singularPath, paths).flatMap(Stream::of).toArray(String[]::new); + } catch (IOException e) { + return singularPath; + } + } else { + return singularPath; + } + } + + /** + * Returns the value of the table name option. 
+ */ + public Optional tableName() { + return get(TABLE_KEY); + } + + /** + * Returns the value of the database name option. + */ + public Optional databaseName() { + return get(DATABASE_KEY); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java deleted file mode 100644 index ddc2acca693a..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import java.util.HashMap; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; - -/** - * An immutable string-to-string map in which keys are case-insensitive. This is used to represent - * data source options. - */ -@InterfaceStability.Evolving -public class DataSourceV2Options { - private final Map keyLowerCasedMap; - - private String toLowerCase(String key) { - return key.toLowerCase(Locale.ROOT); - } - - public static DataSourceV2Options empty() { - return new DataSourceV2Options(new HashMap<>()); - } - - public DataSourceV2Options(Map originalMap) { - keyLowerCasedMap = new HashMap<>(originalMap.size()); - for (Map.Entry entry : originalMap.entrySet()) { - keyLowerCasedMap.put(toLowerCase(entry.getKey()), entry.getValue()); - } - } - - public Map asMap() { - return new HashMap<>(keyLowerCasedMap); - } - - /** - * Returns the option value to which the specified key is mapped, case-insensitively. - */ - public Optional get(String key) { - return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key))); - } - - /** - * Returns the boolean value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. The key match is case-insensitive - */ - public boolean getBoolean(String key, boolean defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Boolean.parseBoolean(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } - - /** - * Returns the integer value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. The key match is case-insensitive - */ - public int getInt(String key, int defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Integer.parseInt(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } - - /** - * Returns the long value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. 
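Stepping back from the diff for a moment, here is a brief usage sketch of the new DataSourceOptions class shown above. Every method it calls appears in the class; the option values themselves are made up for illustration.

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.sql.sources.v2.DataSourceOptions;

public class OptionsSketch {
  public static void main(String[] args) {
    Map<String, String> raw = new HashMap<>();
    raw.put("PATH", "/data/file1");
    raw.put("paths", "[\"path2\", \"/absolute/file3\"]");
    raw.put("Table", "db.tbl");
    raw.put("batchSize", "4096");

    DataSourceOptions options = new DataSourceOptions(raw);

    // Keys are matched case-insensitively.
    System.out.println(options.get("path").get());        // /data/file1
    System.out.println(options.getInt("BATCHSIZE", 1024)); // 4096

    // paths() merges the singular "path" option with the JSON-array "paths" option.
    for (String p : options.paths()) {
      System.out.println(p);
    }

    // tableName() reads the standard "table" key without interpreting it.
    System.out.println(options.tableName().orElse("<none>"));
  }
}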
The key match is case-insensitive - */ - public long getLong(String key, long defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } - - /** - * Returns the double value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. The key match is case-insensitive - */ - public double getDouble(String key, double defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java similarity index 82% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java index 3c87a3db6824..209ffa7a0b9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java @@ -15,14 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; -import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; import org.apache.spark.sql.types.StructType; /** @@ -36,8 +34,8 @@ public interface MicroBatchReadSupport extends DataSourceV2 { * streaming query. * * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and createReadTasks for each batch to process, and then - * call stop() when the execution is complete. Note that a single query may have multiple + * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and + * then call stop() when the execution is complete. Note that a single query may have multiple * executions due to restart or failure recovery. * * @param schema the user provided schema, or empty() if none was provided @@ -50,5 +48,5 @@ public interface MicroBatchReadSupport extends DataSourceV2 { MicroBatchReader createMicroBatchReader( Optional schema, String checkpointLocation, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index 948e20bacf4a..0ea4dc6b5def 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -18,17 +18,17 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * provide data reading ability and scan the data from the data source. 
*/ @InterfaceStability.Evolving -public interface ReadSupport { +public interface ReadSupport extends DataSourceV2 { /** - * Creates a {@link DataSourceV2Reader} to scan the data from this data source. + * Creates a {@link DataSourceReader} to scan the data from this data source. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. @@ -36,5 +36,5 @@ public interface ReadSupport { * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - DataSourceV2Reader createReader(DataSourceV2Options options); + DataSourceReader createReader(DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index b69c6bed8d1b..3801402268af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; /** @@ -30,10 +30,10 @@ * supports both schema inference and user-specified schema. */ @InterfaceStability.Evolving -public interface ReadSupportWithSchema { +public interface ReadSupportWithSchema extends DataSourceV2 { /** - * Create a {@link DataSourceV2Reader} to scan the data from this data source. + * Create a {@link DataSourceReader} to scan the data from this data source. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. @@ -45,5 +45,5 @@ public interface ReadSupportWithSchema { * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options); + DataSourceReader createReader(StructType schema, DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index 3cb020d2e083..9d66805d79b9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -25,7 +25,7 @@ * session. */ @InterfaceStability.Evolving -public interface SessionConfigSupport { +public interface SessionConfigSupport extends DataSourceV2 { /** * Key prefix of the session configs to propagate. 
Spark will extract all session configs that diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java similarity index 71% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java index dee493cadb71..a77b01497269 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java @@ -15,42 +15,38 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; - -import java.util.Optional; +package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for continuous stream processing. + * provide data writing ability for structured streaming. */ @InterfaceStability.Evolving -public interface ContinuousWriteSupport extends BaseStreamingSink { +public interface StreamWriteSupport extends DataSourceV2, BaseStreamingSink { /** - * Creates an optional {@link ContinuousWriter} to save the data to this data source. Data + * Creates an optional {@link StreamWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done. * * @param queryId A unique string for the writing query. It's possible that there are many * writing queries running at the same time, and the returned - * {@link DataSourceV2Writer} can use this id to distinguish itself from others. + * {@link DataSourceWriter} can use this id to distinguish itself from others. * @param schema the schema of the data to be written. * @param mode the output mode which determines what successive epoch output means to this * sink, please refer to {@link OutputMode} for more details. * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. 
*/ - Optional createContinuousWriter( + StreamWriter createStreamWriter( String queryId, StructType schema, OutputMode mode, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index 1e3b644d8c4a..cab56453816c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -21,7 +21,7 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.types.StructType; /** @@ -29,17 +29,17 @@ * provide data writing ability and save the data to the data source. */ @InterfaceStability.Evolving -public interface WriteSupport { +public interface WriteSupport extends DataSourceV2 { /** - * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data + * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done according to the save mode. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. * * @param jobId A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceV2Writer} can + * jobs running at the same time, and the returned {@link DataSourceWriter} can * use this job id to distinguish itself from other jobs. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data @@ -47,6 +47,6 @@ public interface WriteSupport { * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. */ - Optional createWriter( - String jobId, StructType schema, SaveMode mode, DataSourceV2Options options); + Optional createWriter( + String jobId, StructType schema, SaveMode mode, DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java new file mode 100644 index 000000000000..a61697649c43 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; + +/** + * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can + * implement this interface to provide creating {@link DataReader} with particular offset. + */ +@InterfaceStability.Evolving +public interface ContinuousDataReaderFactory extends DataReaderFactory { + /** + * Create a DataReader with particular offset as its startOffset. + * + * @param offset offset want to set as the DataReader's startOffset. + */ + DataReader createDataReaderWithOffset(PartitionOffset offset); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index 8f58c865b620..bb9790a1c819 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -23,7 +23,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link ReadTask#createDataReader()} and is responsible for + * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for * outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java similarity index 65% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java index fa161cdb8b34..32e98e8f5d8b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java @@ -22,21 +22,23 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A read task returned by {@link DataSourceV2Reader#createReadTasks()} and is responsible for - * creating the actual data reader. The relationship between {@link ReadTask} and {@link DataReader} + * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is + * responsible for creating the actual data reader. The relationship between + * {@link DataReaderFactory} and {@link DataReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that, the read task will be serialized and sent to executors, then the data reader will be - * created on executors and do the actual reading. So {@link ReadTask} must be serializable and - * {@link DataReader} doesn't need to be. + * Note that, the reader factory will be serialized and sent to executors, then the data reader + * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be + * serializable and {@link DataReader} doesn't need to be. */ @InterfaceStability.Evolving -public interface ReadTask extends Serializable { +public interface DataReaderFactory extends Serializable { /** - * The preferred locations where this read task can run faster, but Spark does not guarantee that - * this task will always run on these locations. The implementations should make sure that it can - * be run on any location. 
The location is a string representing the host name. + * The preferred locations where the data reader returned by this reader factory can run faster, + * but Spark does not guarantee to run the data reader on these locations. + * The implementations should make sure that it can be run on any location. + * The location is a string representing the host name. * * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in * the returned locations. By default this method returns empty string array, which means this @@ -50,7 +52,7 @@ default String[] preferredLocations() { } /** - * Returns a data reader to do the actual reading work for this read task. + * Returns a data reader to do the actual reading work. * * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java similarity index 70% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index 95ee4a827832..a470bccc5aad 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -21,16 +21,18 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.types.StructType; /** * A data source reader that is returned by - * {@link org.apache.spark.sql.sources.v2.ReadSupport#createReader( - * org.apache.spark.sql.sources.v2.DataSourceV2Options)} or - * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader( - * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}. + * {@link ReadSupport#createReader(DataSourceOptions)} or + * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. + * logic is delegated to {@link DataReaderFactory}s that are returned by + * {@link #createDataReaderFactories()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column @@ -38,7 +40,10 @@ * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. * Names of these interfaces start with `SupportsReporting`. * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. - * Names of these interfaces start with `SupportsScan`. + * Names of these interfaces start with `SupportsScan`. Note that a reader should only + * implement at most one of the special scans, if more than one special scans are implemented, + * only one of them would be respected, according to the priority list from high to low: + * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * * If an exception was throw when applying any of these query optimizations, the action would fail * and no Spark job was submitted. 
@@ -48,7 +53,7 @@ * issues the scan request and does the actual data reading. */ @InterfaceStability.Evolving -public interface DataSourceV2Reader { +public interface DataSourceReader { /** * Returns the actual schema of this data source reader, which may be different from the physical @@ -60,9 +65,9 @@ public interface DataSourceV2Reader { StructType readSchema(); /** - * Returns a list of read tasks. Each task is responsible for outputting data for one RDD - * partition. That means the number of tasks returned here is same as the number of RDD - * partitions this scan outputs. + * Returns a list of reader factories. Each factory is responsible for creating a data reader to + * output data for one RDD partition. That means the number of factories returned here is same as + * the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization * interfaces like column pruning, filter push-down, etc. These optimizations are applied before @@ -71,5 +76,5 @@ public interface DataSourceV2Reader { * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. */ - List> createReadTasks(); + List> createDataReaderFactories(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index f76c687f450c..4543c143a9ac 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down arbitrary expressions as predicates to the data source. * This is an experimental and unstable interface as {@link Expression} is not public and may get * changed in the future Spark versions. @@ -31,15 +31,24 @@ * process this interface. */ @InterfaceStability.Unstable -public interface SupportsPushDownCatalystFilters { +public interface SupportsPushDownCatalystFilters extends DataSourceReader { /** - * Pushes down filters, and returns unsupported filters. + * Pushes down filters, and returns filters that need to be evaluated after scanning. */ Expression[] pushCatalystFilters(Expression[] filters); /** - * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. + * Returns the catalyst filters that are pushed to the data source via + * {@link #pushCatalystFilters(Expression[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for * this case. 
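To make the renamed read-path interfaces concrete, here is a minimal, hypothetical end-to-end sketch wiring ReadSupport, DataSourceReader, DataReaderFactory and DataReader together. The class names and the fabricated integer rows are illustrative only; the interface methods are the ones introduced or renamed in this patch.

import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupport;
import org.apache.spark.sql.sources.v2.reader.DataReader;
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SimpleDataSourceV2 implements DataSourceV2, ReadSupport {

  @Override
  public DataSourceReader createReader(DataSourceOptions options) {
    return new Reader();
  }

  static class Reader implements DataSourceReader {
    @Override
    public StructType readSchema() {
      return new StructType().add("i", DataTypes.IntegerType);
    }

    @Override
    public List<DataReaderFactory<Row>> createDataReaderFactories() {
      // One factory per RDD partition; each factory is serialized and sent to an executor.
      return Arrays.asList(new SimpleFactory(0, 5), new SimpleFactory(5, 10));
    }
  }

  static class SimpleFactory implements DataReaderFactory<Row> {
    private final int start;
    private final int end;

    SimpleFactory(int start, int end) {
      this.start = start;
      this.end = end;
    }

    @Override
    public DataReader<Row> createDataReader() {
      return new DataReader<Row>() {
        private int current = start - 1;

        @Override
        public boolean next() {
          current++;
          return current < end;
        }

        @Override
        public Row get() {
          return RowFactory.create(current);
        }

        @Override
        public void close() { }
      };
    }
  }
}

Loading such a source with spark.read().format(SimpleDataSourceV2.class.getName()).load() should then plan one Spark partition per returned factory, under the usual v2 resolution by class name.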
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 6b0c9d417eea..b6a90a3d0b68 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.sources.Filter; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down filters to the data source and reduce the size of the data to be read. * * Note that, if data source readers implement both this interface and @@ -29,15 +29,23 @@ * {@link SupportsPushDownCatalystFilters}. */ @InterfaceStability.Evolving -public interface SupportsPushDownFilters { +public interface SupportsPushDownFilters extends DataSourceReader { /** - * Pushes down filters, and returns unsupported filters. + * Pushes down filters, and returns filters that need to be evaluated after scanning. */ Filter[] pushFilters(Filter[] filters); /** - * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * Returns the filters that are pushed to the data source via {@link #pushFilters(Filter[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} * is never called, empty array should be returned for this case. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index fe0ac8ee0ee3..427b4d00a112 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -21,12 +21,12 @@ import org.apache.spark.sql.types.StructType; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. */ @InterfaceStability.Evolving -public interface SupportsPushDownRequiredColumns { +public interface SupportsPushDownRequiredColumns extends DataSourceReader { /** * Applies column pruning w.r.t. the given requiredSchema. @@ -35,7 +35,7 @@ public interface SupportsPushDownRequiredColumns { * also OK to do the pruning partially, e.g., a data source may not be able to prune nested * fields, and only prune top-level columns. * - * Note that, data source readers should update {@link DataSourceV2Reader#readSchema()} after + * Note that, data source readers should update {@link DataSourceReader#readSchema()} after * applying column pruning. 
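The clarified pushdown contract above (pushed filters may still need re-evaluation, and the pushed set must be reported back to Spark) can be illustrated with a small hypothetical reader that only understands IsNotNull. The class is abstract because it omits the scan methods shown earlier; the pushedFilters() accessor is assumed from the javadoc in this patch.

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.IsNotNull;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters;

public abstract class FilteringReaderSketch implements DataSourceReader, SupportsPushDownFilters {
  private Filter[] pushed = new Filter[0];

  @Override
  public Filter[] pushFilters(Filter[] filters) {
    List<Filter> accepted = new ArrayList<>();
    List<Filter> postScan = new ArrayList<>();
    for (Filter f : filters) {
      if (f instanceof IsNotNull) {
        accepted.add(f);   // case 1: fully handled by the source
      } else {
        postScan.add(f);   // case 3: Spark must evaluate these after the scan
      }
    }
    pushed = accepted.toArray(new Filter[0]);
    // The returned filters are the ones Spark still has to evaluate after scanning.
    return postScan.toArray(new Filter[0]);
  }

  @Override
  public Filter[] pushedFilters() {
    return pushed;
  }
}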
*/ void pruneColumns(StructType requiredSchema); diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java similarity index 52% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index c0e7bb20cce8..607628746e87 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -14,26 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer -import org.apache.spark.deploy.k8s.MountSecretsBootstrap +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** - * An init-container configuration step for mounting user-specified secrets onto user-specified - * paths. + * A mix in interface for {@link DataSourceReader}. Data source readers can implement this + * interface to report data partitioning and try to avoid shuffle at Spark side. * - * @param bootstrap a utility actually handling mounting of the secrets + * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid + * adding a shuffle even if the reader does not implement this interface. */ -private[spark] class InitContainerMountSecretsStep( - bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { +@InterfaceStability.Evolving +public interface SupportsReportPartitioning extends DataSourceReader { - override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { - val (driverPod, initContainer) = bootstrap.mountSecrets( - spec.driverPod, - spec.initContainer) - spec.copy( - driverPod = driverPod, - initContainer = initContainer - ) - } + /** + * Returns the output data partitioning that this reader guarantees. + */ + Partitioning outputPartitioning(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index c019d2f819ab..11bb13fd3b21 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -20,11 +20,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. */ @InterfaceStability.Evolving -public interface SupportsReportStatistics { +public interface SupportsReportStatistics extends DataSourceReader { /** * Returns the basic statistics of this data source. 
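SupportsReportPartitioning, renamed above, lets a reader promise a partitioning so Spark can avoid a shuffle. A hypothetical Partitioning that satisfies only a ClusteredDistribution on a known column (both classes are introduced further down in this patch) might look like the sketch below; the column name "key" is an assumption.

import java.util.Arrays;
import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution;
import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution;
import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;

public class HashedByKeyPartitioning implements Partitioning {
  private final int numPartitions;

  public HashedByKeyPartitioning(int numPartitions) {
    this.numPartitions = numPartitions;
  }

  @Override
  public int numPartitions() {
    return numPartitions;
  }

  @Override
  public boolean satisfy(Distribution distribution) {
    // Only claim to satisfy distributions we recognize; unknown ones must return false,
    // since Spark may add new Distribution implementations in later releases.
    if (distribution instanceof ClusteredDistribution) {
      String[] cols = ((ClusteredDistribution) distribution).clusteredColumns;
      return Arrays.asList(cols).contains("key");
    }
    return false;
  }
}

A reader mixing in SupportsReportPartitioning would return an instance of this class from outputPartitioning().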
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java new file mode 100644 index 000000000000..2e5cfa78511f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this + * interface to output {@link ColumnarBatch} and make the scan faster. + */ +@InterfaceStability.Evolving +public interface SupportsScanColumnarBatch extends DataSourceReader { + @Override + default List> createDataReaderFactories() { + throw new IllegalStateException( + "createDataReaderFactories not supported by default within SupportsScanColumnarBatch."); + } + + /** + * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data + * in batches. + */ + List> createBatchDataReaderFactories(); + + /** + * Returns true if the concrete data source reader can read data in batch according to the scan + * properties like required columns, pushes filters, etc. It's possible that the implementation + * can only support some certain columns with certain types. Users can overwrite this method and + * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions. + */ + default boolean enableBatchRead() { + return true; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index b90ec880dc85..9cd749e8e4ce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -24,22 +24,23 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side. * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get * changed in the future Spark versions. 
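Below is a hedged sketch of a reader opting into the new columnar scan path added by SupportsScanColumnarBatch. It fabricates a single three-row batch with OnHeapColumnVector and uses the ColumnarBatch constructor introduced earlier in this patch; the class name and data are illustrative only.

import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.sources.v2.reader.DataReader;
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;

public class ColumnarReaderSketch implements DataSourceReader, SupportsScanColumnarBatch {

  @Override
  public StructType readSchema() {
    return new StructType().add("i", DataTypes.IntegerType);
  }

  @Override
  public List<DataReaderFactory<ColumnarBatch>> createBatchDataReaderFactories() {
    // Single partition; the factory's only abstract method is createDataReader().
    DataReaderFactory<ColumnarBatch> factory = () -> new DataReader<ColumnarBatch>() {
      private boolean produced = false;

      @Override
      public boolean next() {
        if (produced) return false;
        produced = true;
        return true;
      }

      @Override
      public ColumnarBatch get() {
        OnHeapColumnVector col = new OnHeapColumnVector(3, DataTypes.IntegerType);
        for (int i = 0; i < 3; i++) col.putInt(i, i);
        ColumnarBatch batch = new ColumnarBatch(new ColumnVector[]{col});
        batch.setNumRows(3);
        return batch;
      }

      @Override
      public void close() { }
    };
    return Arrays.asList(factory);
  }
}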
*/ @InterfaceStability.Unstable -public interface SupportsScanUnsafeRow extends DataSourceV2Reader { +public interface SupportsScanUnsafeRow extends DataSourceReader { @Override - default List> createReadTasks() { + default List> createDataReaderFactories() { throw new IllegalStateException( - "createReadTasks should not be called with SupportsScanUnsafeRow."); + "createDataReaderFactories not supported by default within SupportsScanUnsafeRow"); } /** - * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns data in unsafe row format. + * Similar to {@link DataSourceReader#createDataReaderFactories()}, + * but returns data in unsafe row format. */ - List> createUnsafeRowReadTasks(); + List> createUnsafeRowReaderFactories(); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java similarity index 53% rename from common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 74ebc87dc978..2d0ee50212b5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -15,40 +15,25 @@ * limitations under the License. */ -package org.apache.spark.unsafe.memory; +package org.apache.spark.sql.sources.v2.reader.partitioning; -import javax.annotation.Nullable; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** - * A memory location. Tracked either by a memory address (with off-heap allocation), - * or by an offset from a JVM object (in-heap allocation). + * A concrete implementation of {@link Distribution}. Represents a distribution where records that + * share the same values for the {@link #clusteredColumns} will be produced by the same + * {@link DataReader}. */ -public class MemoryLocation { +@InterfaceStability.Evolving +public class ClusteredDistribution implements Distribution { - @Nullable - Object obj; + /** + * The names of the clustered columns. Note that they are order insensitive. + */ + public final String[] clusteredColumns; - long offset; - - public MemoryLocation(@Nullable Object obj, long offset) { - this.obj = obj; - this.offset = offset; - } - - public MemoryLocation() { - this(null, 0); - } - - public void setObjAndOffset(Object newObj, long newOffset) { - this.obj = newObj; - this.offset = newOffset; - } - - public final Object getBaseObject() { - return obj; - } - - public final long getBaseOffset() { - return offset; + public ClusteredDistribution(String[] clusteredColumns) { + this.clusteredColumns = clusteredColumns; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java new file mode 100644 index 000000000000..f6b111fdf220 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.partitioning; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReader; + +/** + * An interface to represent data distribution requirement, which specifies how the records should + * be distributed among the data partitions(one {@link DataReader} outputs data for one partition). + * Note that this interface has nothing to do with the data ordering inside one + * partition(the output records of a single {@link DataReader}). + * + * The instance of this interface is created and provided by Spark, then consumed by + * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to + * implement this interface, but need to catch as more concrete implementations of this interface + * as possible in {@link Partitioning#satisfy(Distribution)}. + * + * Concrete implementations until now: + *
+ * <ul>
+ *   <li>{@link ClusteredDistribution}</li>
+ * </ul>
      + */ +@InterfaceStability.Evolving +public interface Distribution {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java new file mode 100644 index 000000000000..309d9e5de0a0 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.partitioning; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; + +/** + * An interface to represent the output data partitioning for a data source, which is returned by + * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a + * snapshot. Once created, it should be deterministic and always report the same number of + * partitions and the same "satisfy" result for a certain distribution. + */ +@InterfaceStability.Evolving +public interface Partitioning { + + /** + * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs. + */ + int numPartitions(); + + /** + * Returns true if this partitioning can satisfy the given distribution, which means Spark does + * not need to shuffle the output data of this data source for some certain operations. + * + * Note that, Spark may add new concrete implementations of {@link Distribution} in new releases. + * This method should be aware of it and always return false for unrecognized distributions. It's + * recommended to check every Spark new release and support new distributions if possible, to + * avoid shuffle at Spark side for more cases. + */ + boolean satisfy(Distribution distribution); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java index ca9a290e97a0..47d26440841f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java @@ -15,13 +15,15 @@ * limitations under the License. 
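To make the Partitioning and ClusteredDistribution contracts above concrete, here is a hedged, hypothetical sketch (the class name and fields are not part of this patch). It is deliberately conservative: satisfy() recognizes only ClusteredDistribution over exactly the same column set, and returns false for any distribution it does not recognize, as the javadoc above recommends.

    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.Set;

    import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution;
    import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution;
    import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;

    class ExampleClusteredPartitioning implements Partitioning {
      private final Set<String> clusteredColumns;
      private final int numPartitions;

      ExampleClusteredPartitioning(String[] clusteredColumns, int numPartitions) {
        this.clusteredColumns = new HashSet<>(Arrays.asList(clusteredColumns));
        this.numPartitions = numPartitions;
      }

      @Override
      public int numPartitions() {
        return numPartitions;
      }

      @Override
      public boolean satisfy(Distribution distribution) {
        if (distribution instanceof ClusteredDistribution) {
          // Order-insensitive comparison; conservatively require the exact same column set.
          Set<String> required =
            new HashSet<>(Arrays.asList(((ClusteredDistribution) distribution).clusteredColumns));
          return required.equals(clusteredColumns);
        }
        // Unknown Distribution subtypes (possibly added in future Spark releases): report false.
        return false;
      }
    }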
*/ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A variation on {@link DataReader} for use with streaming in continuous processing mode. */ +@InterfaceStability.Evolving public interface ContinuousDataReader extends DataReader { /** * Get the offset of the current record, or the start offset if no records have been read. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java similarity index 60% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index f0b205869ed6..7fe7f00ac2fa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -15,23 +15,28 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import java.util.Optional; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. * - * Implementations must ensure each read task output is a {@link ContinuousDataReader}. + * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ -public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reader { +@InterfaceStability.Evolving +public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** - * Merge offsets coming from {@link ContinuousDataReader} instances in each partition to - * a single global offset. + * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each + * partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); @@ -42,23 +47,23 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reade Offset deserializeOffset(String json); /** - * Set the desired start offset for read tasks created from this reader. The scan will start - * from the first record after the provided offset, or from an implementation-defined inferred - * starting point if no offset is provided. + * Set the desired start offset for reader factories created from this reader. The scan will + * start from the first record after the provided offset, or from an implementation-defined + * inferred starting point if no offset is provided. */ - void setOffset(Optional start); + void setStartOffset(Optional start); /** * Return the specified or inferred start offset for this reader. 
* - * @throws IllegalStateException if setOffset has not been called + * @throws IllegalStateException if setStartOffset has not been called */ Offset getStartOffset(); /** - * The execution engine will call this method in every epoch to determine if new read tasks need - * to be generated, which may be required if for example the underlying source system has had - * partitions added or removed. + * The execution engine will call this method in every epoch to determine if new reader + * factories need to be generated, which may be required if for example the underlying + * source system has had partitions added or removed. * * If true, the query will be shut down and restarted with a new reader. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java similarity index 75% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java index 70ff75680603..67ebde30d61a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java @@ -15,22 +15,27 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import java.util.Optional; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to indicate they allow micro-batch streaming reads. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ -public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource { +@InterfaceStability.Evolving +public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { /** - * Set the desired offset range for read tasks created from this reader. Read tasks will - * generate only data within (`start`, `end`]; that is, from the first record after `start` to - * the record with offset `end`. + * Set the desired offset range for reader factories created from this reader. Reader factories + * will generate only data within (`start`, `end`]; that is, from the first record after `start` + * to the record with offset `end`. * * @param start The initial offset to scan from. If not specified, scan from an * implementation-specified start point, such as the earliest available record. 
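A hedged sketch of the start-offset and offset-merging contract described above for a continuous reader follows. It is intentionally abstract: the earliestOffset/combine helpers are hypothetical, and the remaining DataSourceReader/ContinuousReader methods are left to a real implementation.

    import java.util.Optional;

    import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader;
    import org.apache.spark.sql.sources.v2.reader.streaming.Offset;
    import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;

    abstract class ExampleContinuousReader implements ContinuousReader {
      private Offset startOffset;

      @Override
      public void setStartOffset(Optional<Offset> start) {
        // Fall back to an implementation-defined starting point when no offset is given.
        this.startOffset = start.orElseGet(this::earliestOffset);
      }

      @Override
      public Offset getStartOffset() {
        if (startOffset == null) {
          throw new IllegalStateException("setStartOffset has not been called");
        }
        return startOffset;
      }

      @Override
      public Offset mergeOffsets(PartitionOffset[] offsets) {
        // Collapse the per-partition positions reported by ContinuousDataReaders
        // into a single global, JSON-serializable offset.
        return combine(offsets);
      }

      // Source-specific helpers; hypothetical, not part of the API in this diff.
      protected abstract Offset earliestOffset();
      protected abstract Offset combine(PartitionOffset[] offsets);
    }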
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java similarity index 70% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index 60b87f2ac075..e41c0351edc8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -15,14 +15,22 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; /** - * An abstract representation of progress through a [[MicroBatchReader]] or [[ContinuousReader]]. - * During execution, Offsets provided by the data source implementation will be logged and used as - * restart checkpoints. Sources should provide an Offset implementation which they can use to - * reconstruct the stream position where the offset was taken. + * An abstract representation of progress through a {@link MicroBatchReader} or + * {@link ContinuousReader}. + * During execution, offsets provided by the data source implementation will be logged and used as + * restart checkpoints. Each source should provide an offset implementation which the source can use + * to reconstruct a position in the stream up to which data has been seen/processed. + * + * Note: This class currently extends {@link org.apache.spark.sql.execution.streaming.Offset} to + * maintain compatibility with DataSource V1 APIs. This extension will be removed once we + * get rid of V1 completely. */ +@InterfaceStability.Evolving public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset { /** * A JSON-serialized representation of an Offset that is @@ -37,7 +45,7 @@ public abstract class Offset extends org.apache.spark.sql.execution.streaming.Of /** * Equality based on JSON string representation. We leverage the * JSON representation for normalization between the Offset's - * in memory and on disk representations. + * in deserialized and serialized representations. */ @Override public boolean equals(Object obj) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java similarity index 88% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java index eca0085c8a8c..383e73db6762 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java @@ -15,15 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import java.io.Serializable; +import org.apache.spark.annotation.InterfaceStability; + /** * Used for per-partition offsets in continuous processing. ContinuousReader implementations will * provide a method to merge these into a global Offset. * * These offsets must be serializable. 
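Below is a minimal, hypothetical sketch of the offset types described above. The Offset subclass illustrates the JSON-based contract: equality is driven entirely by the json() string, which is what gets logged to the checkpoint and later handed back to deserializeOffset. The PartitionOffset variant simply needs to be serializable so executors can ship it to the driver for mergeOffsets(). Both class names and fields are illustrative only.

    import java.io.Serializable;

    import org.apache.spark.sql.sources.v2.reader.streaming.Offset;
    import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;

    class ExampleLongOffset extends Offset {
      private final long position;

      ExampleLongOffset(long position) {
        this.position = position;
      }

      @Override
      public String json() {
        // A bare number is valid JSON; the base class derives equals/hashCode from this string.
        return Long.toString(position);
      }
    }

    class ExamplePartitionOffset implements PartitionOffset, Serializable {
      final int partition;
      final long position;

      ExamplePartitionOffset(int partition, long position) {
        this.partition = partition;
        this.position = position;
      }
    }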
*/ +@InterfaceStability.Evolving public interface PartitionOffset extends Serializable { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java deleted file mode 100644 index 53ffa95ae0f4..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.streaming; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability and save the data from a microbatch to the data source. - */ -@InterfaceStability.Evolving -public interface MicroBatchWriteSupport extends BaseStreamingSink { - - /** - * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data - * sources can return None if there is no writing needed to be done. - * - * @param queryId A unique string for the writing query. It's possible that there are many writing - * queries running at the same time, and the returned {@link DataSourceV2Writer} - * can use this id to distinguish itself from others. - * @param epochId The unique numeric ID of the batch within this writing query. This is an - * incrementing counter representing a consistent set of data; the same batch may - * be started multiple times in failure recovery scenarios, but it will always - * contain the same records. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive batch output means to this - * sink, please refer to {@link OutputMode} for more details. - * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. 
- */ - Optional createMicroBatchWriter( - String queryId, - long epochId, - StructType schema, - OutputMode mode, - DataSourceV2Options options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java deleted file mode 100644 index 723395bd1e96..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.streaming.writer; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; -import org.apache.spark.sql.sources.v2.writer.DataWriter; -import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; - -/** - * A {@link DataSourceV2Writer} for use with continuous stream processing. - */ -@InterfaceStability.Evolving -public interface ContinuousWriter extends DataSourceV2Writer { - /** - * Commits this writing job for the specified epoch with a list of commit messages. The commit - * messages are collected from successful data writers and are produced by - * {@link DataWriter#commit()}. - * - * If this method fails (by throwing an exception), this writing job is considered to have been - * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. 
- */ - void commit(long epochId, WriterCommitMessage[] messages); - - default void commit(WriterCommitMessage[] messages) { - throw new UnsupportedOperationException( - "Commit without epoch should not be called with ContinuousWriter"); - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java similarity index 70% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index fc37b9a516f8..0a0fd8db5803 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -20,13 +20,17 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.StreamWriteSupport; import org.apache.spark.sql.sources.v2.WriteSupport; +import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}. + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ + * {@link StreamWriteSupport#createStreamWriter( + * String, StructType, OutputMode, DataSourceOptions)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * @@ -49,7 +53,7 @@ * Please refer to the documentation of commit/abort methods for detailed specifications. */ @InterfaceStability.Evolving -public interface DataSourceV2Writer { +public interface DataSourceWriter { /** * Creates a writer factory which will be serialized and sent to executors. @@ -59,6 +63,24 @@ public interface DataSourceV2Writer { */ DataWriterFactory createWriterFactory(); + /** + * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for + * each task commits. + * + * @return true if commit coordinator should be used, false otherwise. + */ + default boolean useCommitCoordinator() { + return true; + } + + /** + * Handles a commit message on receiving from a successful data writer. + * + * If this method fails (by throwing an exception), this writing job is considered to to have been + * failed, and {@link #abort(WriterCommitMessage[])} would be called. + */ + default void onDataWriterCommit(WriterCommitMessage message) {} + /** * Commits this writing job with a list of commit messages. The commit messages are collected from * successful data writers and are produced by {@link DataWriter#commit()}. @@ -67,16 +89,19 @@ public interface DataSourceV2Writer { * failed, and {@link #abort(WriterCommitMessage[])} would be called. The state of the destination * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * - * Note that, one partition may have multiple committed data writers because of speculative tasks. - * Spark will pick the first successful one and get its commit message. 
Implementations should be - * aware of this and handle it correctly, e.g., have a coordinator to make sure only one data - * writer can commit, or have a way to clean up the data of already-committed writers. + * Note that speculative execution may cause multiple tasks to run for a partition. By default, + * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can + * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple + * attempts may have committed successfully and one successful commit message per task will be + * passed to this commit method. The remaining commit messages are ignored by Spark. */ void commit(WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed and keep failing when retry, or - * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * Aborts this writing job because some data writers are failed and keep failing when retry, + * or the Spark job fails with some unknown reasons, + * or {@link #onDataWriterCommit(WriterCommitMessage)} fails, + * or {@link #commit(WriterCommitMessage[])} fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 04b03e63de50..39bf45829886 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -31,13 +31,17 @@ * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will * not be processed. If all records are successfully written, {@link #commit()} is called. * + * Once a data writer returns successfully from {@link #commit()} or {@link #abort()}, its lifecycle + * is over and Spark will not use it again. + * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} with commit messages from other data + * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an - * exception will be sent to the driver side, and Spark will retry this writing task for some times, - * each time {@link DataWriterFactory#createDataWriter(int, int)} gets a different `attemptNumber`, - * and finally call {@link DataSourceV2Writer#abort(WriterCommitMessage[])} if all retry fail. + * exception will be sent to the driver side, and Spark may retry this writing task a few times. + * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a + * different `attemptNumber`. 
Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the @@ -69,11 +73,11 @@ public interface DataWriter { /** * Commits this writer after all records are written successfully, returns a commit message which * will be sent back to driver side and passed to - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * {@link DataSourceWriter#commit(WriterCommitMessage[])}. * * The written data should only be visible to data source readers after - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} succeeds, which means this method - * should still "hide" the written data and ask the {@link DataSourceV2Writer} at driver side to + * {@link DataSourceWriter#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link DataSourceWriter} at driver side to * do the final commit via {@link WriterCommitMessage}. * * If this method fails (by throwing an exception), {@link #abort()} will be called and this @@ -91,7 +95,7 @@ public interface DataWriter { * failed. * * If this method fails(by throwing an exception), the underlying data source may have garbage - * that need to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, + * that need to be cleaned by {@link DataSourceWriter#abort(WriterCommitMessage[])} or manually, * but these garbage should not be visible to data source readers. * * @throws IOException if failure happens during disk/network IO like writing files. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 18ec792f5a2c..c2c2ab73257e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A factory of {@link DataWriter} returned by {@link DataSourceV2Writer#createWriterFactory()}, + * A factory of {@link DataWriter} returned by {@link DataSourceWriter#createWriterFactory()}, * which is responsible for creating and initializing the actual data writer at executor side. * * Note that, the writer factory will be serialized and sent to executors, then the data writer @@ -48,6 +48,9 @@ public interface DataWriterFactory extends Serializable { * same task id but different attempt number, which means there are multiple * tasks with the same task id running at the same time. Implementations can * use this attempt number to distinguish writers of different task attempts. + * @param epochId A monotonically increasing id for streaming queries that are split in to + * discrete periods of execution. For non-streaming queries, + * this ID will always be 0. 
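A hedged sketch of the executor-side handshake implied by the new createDataWriter(partitionId, attemptNumber, epochId) signature follows. The class names are hypothetical and the in-memory buffering is purely illustrative (a real sink would stage data somewhere durable and only make it visible after the driver-side commit).

    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.List;

    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.sources.v2.writer.DataWriter;
    import org.apache.spark.sql.sources.v2.writer.DataWriterFactory;
    import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;

    class ExampleCommitMessage implements WriterCommitMessage {
      final int partitionId;
      final long epochId;
      final int rowCount;

      ExampleCommitMessage(int partitionId, long epochId, int rowCount) {
        this.partitionId = partitionId;
        this.epochId = epochId;
        this.rowCount = rowCount;
      }
    }

    class ExampleWriterFactory implements DataWriterFactory<Row> {
      @Override
      public DataWriter<Row> createDataWriter(int partitionId, int attemptNumber, long epochId) {
        // For a batch query epochId is always 0; for streaming it identifies the epoch.
        return new DataWriter<Row>() {
          private final List<Row> buffered = new ArrayList<>();

          @Override
          public void write(Row record) throws IOException {
            buffered.add(record); // stage records; nothing is visible yet
          }

          @Override
          public WriterCommitMessage commit() throws IOException {
            // Report what was staged; the driver makes it visible once all partitions commit.
            return new ExampleCommitMessage(partitionId, epochId, buffered.size());
          }

          @Override
          public void abort() throws IOException {
            buffered.clear(); // drop staged data on task failure
          }
        };
      }
    }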
*/ - DataWriter createDataWriter(int partitionId, int attemptNumber); + DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java index 3e0518814f45..d2cf7e01c08c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java @@ -22,14 +22,14 @@ import org.apache.spark.sql.catalyst.InternalRow; /** - * A mix-in interface for {@link DataSourceV2Writer}. Data source writers can implement this + * A mix-in interface for {@link DataSourceWriter}. Data source writers can implement this * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get * changed in the future Spark versions. */ @InterfaceStability.Unstable -public interface SupportsWriteInternalRow extends DataSourceV2Writer { +public interface SupportsWriteInternalRow extends DataSourceWriter { @Override default DataWriterFactory createWriterFactory() { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 082d6b5dc409..9e38836c0edf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -23,10 +23,10 @@ /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side - * as the input parameter of {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * as the input parameter of {@link DataSourceWriter#commit(WriterCommitMessage[])}. * * This is an empty interface, data sources should define their own message class and use it in - * their {@link DataWriter#commit()} and {@link DataSourceV2Writer#commit(WriterCommitMessage[])} + * their {@link DataWriter#commit()} and {@link DataSourceWriter#commit(WriterCommitMessage[])} * implementations. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java new file mode 100644 index 000000000000..a316b2a4c1d8 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2.writer.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.DataWriter; +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; + +/** + * A {@link DataSourceWriter} for use with structured streaming. + * + * Streaming queries are divided into intervals of data called epochs, with a monotonically + * increasing numeric ID. This writer handles commits and aborts for each successive epoch. + */ +@InterfaceStability.Evolving +public interface StreamWriter extends DataSourceWriter { + /** + * Commits this writing job for the specified epoch with a list of commit messages. The commit + * messages are collected from successful data writers and are produced by + * {@link DataWriter#commit()}. + * + * If this method fails (by throwing an exception), this writing job is considered to have been + * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. + * + * The execution engine may call commit() multiple times for the same epoch in some circumstances. + * To support exactly-once data semantics, implementations must ensure that multiple commits for + * the same epoch are idempotent. + */ + void commit(long epochId, WriterCommitMessage[] messages); + + /** + * Aborts this writing job because some data writers are failed and keep failing when retried, or + * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * + * If this method fails (by throwing an exception), the underlying data source may require manual + * cleanup. + * + * Unless the abort is triggered by the failure of commit, the given messages will have some + * null slots, as there may be only a few data writers that were committed before the abort + * happens, or some data writers were committed but their commit messages haven't reached the + * driver when the abort is triggered. So this is just a "best effort" for data sources to + * clean up the data left by data writers. 
+ */ + void abort(long epochId, WriterCommitMessage[] messages); + + default void commit(WriterCommitMessage[] messages) { + throw new UnsupportedOperationException( + "Commit without epoch should not be called with StreamWriter"); + } + + default void abort(WriterCommitMessage[] messages) { + throw new UnsupportedOperationException( + "Abort without epoch should not be called with StreamWriter"); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index 33ae9a9e8766..5371a23230c9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -50,7 +50,7 @@ public static Trigger ProcessingTime(long intervalMs) { * * {{{ * import java.util.concurrent.TimeUnit - * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.writeStream().trigger(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) * }}} * * @since 2.2.0 @@ -66,7 +66,7 @@ public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { * * {{{ * import scala.concurrent.duration._ - * df.writeStream.trigger(ProcessingTime(10.seconds)) + * df.writeStream.trigger(Trigger.ProcessingTime(10.seconds)) * }}} * @since 2.2.0 */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java similarity index 70% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index af5673e26a50..227a16f7e69e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -15,38 +15,32 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; +import io.netty.buffer.ArrowBuf; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; import org.apache.arrow.vector.holders.NullableVarCharHolder; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** - * A column vector backed by Apache Arrow. + * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not + * supported. 
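Returning briefly to the StreamWriter interface added just above: its key obligation is that commit(epochId, ...) may be invoked more than once for the same epoch, so implementations must make it idempotent. A hedged sketch follows; the sink-specific publish/cleanUp hooks are hypothetical, and the in-memory epoch bookkeeping shown here would need to be persisted durably by a real sink to survive restarts.

    import java.util.HashSet;
    import java.util.Set;

    import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;
    import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter;

    abstract class ExampleStreamWriter implements StreamWriter {
      private final Set<Long> committedEpochs = new HashSet<>();

      @Override
      public void commit(long epochId, WriterCommitMessage[] messages) {
        if (!committedEpochs.add(epochId)) {
          return; // epoch already committed; a repeated commit must be a no-op
        }
        publish(epochId, messages); // make the epoch's data visible exactly once
      }

      @Override
      public void abort(long epochId, WriterCommitMessage[] messages) {
        // Best-effort cleanup: some slots in `messages` may be null per the javadoc above.
        cleanUp(epochId, messages);
      }

      // Sink-specific hooks; hypothetical, not part of the API in this patch.
      protected abstract void publish(long epochId, WriterCommitMessage[] messages);
      protected abstract void cleanUp(long epochId, WriterCommitMessage[] messages);
    }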
*/ +@InterfaceStability.Evolving public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; private ArrowColumnVector[] childColumns; - private void ensureAccessible(int index) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index >= valueCount) { - throw new IndexOutOfBoundsException( - String.format("index: %d, valueCount: %d", index, valueCount)); - } - } - - private void ensureAccessible(int index, int count) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index + count > valueCount) { - throw new IndexOutOfBoundsException( - String.format("index range: [%d, %d), valueCount: %d", index, index + count, valueCount)); - } + @Override + public boolean hasNull() { + return accessor.getNullCount() > 0; } @Override @@ -59,218 +53,84 @@ public void close() { if (childColumns != null) { for (int i = 0; i < childColumns.length; i++) { childColumns[i].close(); + childColumns[i] = null; } + childColumns = null; } accessor.close(); } - // - // APIs dealing with nulls - // - @Override public boolean isNullAt(int rowId) { - ensureAccessible(rowId); return accessor.isNullAt(rowId); } - // - // APIs dealing with Booleans - // - @Override public boolean getBoolean(int rowId) { - ensureAccessible(rowId); return accessor.getBoolean(rowId); } - @Override - public boolean[] getBooleans(int rowId, int count) { - ensureAccessible(rowId, count); - boolean[] array = new boolean[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getBoolean(rowId + i); - } - return array; - } - - // - // APIs dealing with Bytes - // - @Override public byte getByte(int rowId) { - ensureAccessible(rowId); return accessor.getByte(rowId); } - @Override - public byte[] getBytes(int rowId, int count) { - ensureAccessible(rowId, count); - byte[] array = new byte[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getByte(rowId + i); - } - return array; - } - - // - // APIs dealing with Shorts - // - @Override public short getShort(int rowId) { - ensureAccessible(rowId); return accessor.getShort(rowId); } - @Override - public short[] getShorts(int rowId, int count) { - ensureAccessible(rowId, count); - short[] array = new short[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getShort(rowId + i); - } - return array; - } - - // - // APIs dealing with Ints - // - @Override public int getInt(int rowId) { - ensureAccessible(rowId); return accessor.getInt(rowId); } - @Override - public int[] getInts(int rowId, int count) { - ensureAccessible(rowId, count); - int[] array = new int[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getInt(rowId + i); - } - return array; - } - - // - // APIs dealing with Longs - // - @Override public long getLong(int rowId) { - ensureAccessible(rowId); return accessor.getLong(rowId); } - @Override - public long[] getLongs(int rowId, int count) { - ensureAccessible(rowId, count); - long[] array = new long[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getLong(rowId + i); - } - return array; - } - - // - // APIs dealing with floats - // - @Override public float getFloat(int rowId) { - ensureAccessible(rowId); return accessor.getFloat(rowId); } - @Override - public float[] getFloats(int rowId, int count) { - ensureAccessible(rowId, count); - float[] array = new float[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getFloat(rowId + i); - } - return array; - } - - // - // APIs dealing with doubles - // - @Override public double getDouble(int 
rowId) { - ensureAccessible(rowId); return accessor.getDouble(rowId); } - @Override - public double[] getDoubles(int rowId, int count) { - ensureAccessible(rowId, count); - double[] array = new double[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getDouble(rowId + i); - } - return array; - } - - // - // APIs dealing with Arrays - // - - @Override - public int getArrayLength(int rowId) { - ensureAccessible(rowId); - return accessor.getArrayLength(rowId); - } - - @Override - public int getArrayOffset(int rowId) { - ensureAccessible(rowId); - return accessor.getArrayOffset(rowId); - } - - // - // APIs dealing with Decimals - // - @Override public Decimal getDecimal(int rowId, int precision, int scale) { - ensureAccessible(rowId); + if (isNullAt(rowId)) return null; return accessor.getDecimal(rowId, precision, scale); } - // - // APIs dealing with UTF8Strings - // - @Override public UTF8String getUTF8String(int rowId) { - ensureAccessible(rowId); + if (isNullAt(rowId)) return null; return accessor.getUTF8String(rowId); } - // - // APIs dealing with Binaries - // - @Override public byte[] getBinary(int rowId) { - ensureAccessible(rowId); + if (isNullAt(rowId)) return null; return accessor.getBinary(rowId); } - /** - * Returns the data for the underlying array. - */ @Override - public ArrowColumnVector arrayData() { return childColumns[0]; } + public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) return null; + return accessor.getArray(rowId); + } - /** - * Returns the ordinal's child data column. - */ @Override - public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } public ArrowColumnVector(ValueVector vector) { super(ArrowUtils.fromArrowField(vector.getField())); @@ -302,11 +162,8 @@ public ArrowColumnVector(ValueVector vector) { } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); - - childColumns = new ArrowColumnVector[1]; - childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); - } else if (vector instanceof MapVector) { - MapVector mapVector = (MapVector) vector; + } else if (vector instanceof NullableMapVector) { + NullableMapVector mapVector = (NullableMapVector) vector; accessor = new StructAccessor(mapVector); childColumns = new ArrowColumnVector[mapVector.size()]; @@ -331,10 +188,6 @@ boolean isNullAt(int rowId) { return vector.isNull(rowId); } - final int getValueCount() { - return vector.getValueCount(); - } - final int getNullCount() { return vector.getNullCount(); } @@ -383,11 +236,7 @@ byte[] getBinary(int rowId) { throw new UnsupportedOperationException(); } - int getArrayLength(int rowId) { - throw new UnsupportedOperationException(); - } - - int getArrayOffset(int rowId) { + ColumnarArray getArray(int rowId) { throw new UnsupportedOperationException(); } } @@ -529,9 +378,10 @@ final UTF8String getUTF8String(int rowId) { if (stringResult.isSet == 0) { return null; } else { - return UTF8String.fromAddress(null, + return new UTF8String(new OffHeapMemoryBlock( stringResult.buffer.memoryAddress() + stringResult.start, - stringResult.end - stringResult.start); + stringResult.end - stringResult.start + )); } } } @@ -584,10 +434,12 @@ final long getLong(int rowId) { private static class ArrayAccessor extends ArrowVectorAccessor { private final 
ListVector accessor; + private final ArrowColumnVector arrayData; ArrayAccessor(ListVector vector) { super(vector); this.accessor = vector; + this.arrayData = new ArrowColumnVector(vector.getDataVector()); } @Override @@ -601,19 +453,26 @@ final boolean isNullAt(int rowId) { } @Override - final int getArrayLength(int rowId) { - return accessor.getInnerValueCountAt(rowId); - } - - @Override - final int getArrayOffset(int rowId) { - return accessor.getOffsetBuffer().getInt(rowId * accessor.OFFSET_WIDTH); + final ColumnarArray getArray(int rowId) { + ArrowBuf offsets = accessor.getOffsetBuffer(); + int index = rowId * accessor.OFFSET_WIDTH; + int start = offsets.getInt(index); + int end = offsets.getInt(index + accessor.OFFSET_WIDTH); + return new ColumnarArray(arrayData, start, end - start); } } + /** + * Any call to "get" method will throw UnsupportedOperationException. + * + * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses + * getStruct() method defined in the parent class. Any call to "get" method in this class is a + * bug in the code. + * + */ private static class StructAccessor extends ArrowVectorAccessor { - StructAccessor(MapVector vector) { + StructAccessor(NullableMapVector vector) { super(vector); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java new file mode 100644 index 000000000000..ad99b450a480 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.vectorized; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * An interface representing in-memory columnar data in Spark. This interface defines the main APIs + * to access the data, as well as their batched versions. The batched versions are considered to be + * faster and preferable whenever possible. + * + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values + * in this ColumnVector. + * + * Spark only calls specific `get` method according to the data type of this {@link ColumnVector}, + * e.g. if it's int type, Spark is guaranteed to only call {@link #getInt(int)} or + * {@link #getInts(int, int)}. + * + * ColumnVector supports all the data types including nested types. To handle nested types, + * ColumnVector can have children and is a tree structure. 
Please refer to {@link #getStruct(int)}, + * {@link #getArray(int)} and {@link #getMap(int)} for the details about how to implement nested + * types. + * + * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating + * memory again and again. + * + * ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint. + * Implementations should prefer computing efficiency over storage efficiency when design the + * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage + * footprint is negligible. + */ +@InterfaceStability.Evolving +public abstract class ColumnVector implements AutoCloseable { + + /** + * Returns the data type of this column vector. + */ + public final DataType dataType() { return type; } + + /** + * Cleans up memory for this column vector. The column vector is not usable after this. + * + * This overwrites `AutoCloseable.close` to remove the `throws` clause, as column vector is + * in-memory and we don't expect any exception to happen during closing. + */ + @Override + public abstract void close(); + + /** + * Returns true if this column vector contains any null values. + */ + public abstract boolean hasNull(); + + /** + * Returns the number of nulls in this column vector. + */ + public abstract int numNulls(); + + /** + * Returns whether the value at rowId is NULL. + */ + public abstract boolean isNullAt(int rowId); + + /** + * Returns the boolean type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. + */ + public abstract boolean getBoolean(int rowId); + + /** + * Gets boolean type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. + */ + public boolean[] getBooleans(int rowId, int count) { + boolean[] res = new boolean[count]; + for (int i = 0; i < count; i++) { + res[i] = getBoolean(rowId + i); + } + return res; + } + + /** + * Returns the byte type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. + */ + public abstract byte getByte(int rowId); + + /** + * Gets byte type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. + */ + public byte[] getBytes(int rowId, int count) { + byte[] res = new byte[count]; + for (int i = 0; i < count; i++) { + res[i] = getByte(rowId + i); + } + return res; + } + + /** + * Returns the short type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. + */ + public abstract short getShort(int rowId); + + /** + * Gets short type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. + */ + public short[] getShorts(int rowId, int count) { + short[] res = new short[count]; + for (int i = 0; i < count; i++) { + res[i] = getShort(rowId + i); + } + return res; + } + + /** + * Returns the int type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. + */ + public abstract int getInt(int rowId); + + /** + * Gets int type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. + */ + public int[] getInts(int rowId, int count) { + int[] res = new int[count]; + for (int i = 0; i < count; i++) { + res[i] = getInt(rowId + i); + } + return res; + } + + /** + * Returns the long type value for rowId. 
The return value is undefined and can be anything, + * if the slot for rowId is null. + */ + public abstract long getLong(int rowId); + + /** + * Gets long type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. + */ + public long[] getLongs(int rowId, int count) { + long[] res = new long[count]; + for (int i = 0; i < count; i++) { + res[i] = getLong(rowId + i); + } + return res; + } + + /** + * Returns the float type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. + */ + public abstract float getFloat(int rowId); + + /** + * Gets float type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. + */ + public float[] getFloats(int rowId, int count) { + float[] res = new float[count]; + for (int i = 0; i < count; i++) { + res[i] = getFloat(rowId + i); + } + return res; + } + + /** + * Returns the double type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. + */ + public abstract double getDouble(int rowId); + + /** + * Gets double type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. + */ + public double[] getDoubles(int rowId, int count) { + double[] res = new double[count]; + for (int i = 0; i < count; i++) { + res[i] = getDouble(rowId + i); + } + return res; + } + + /** + * Returns the struct type value for rowId. If the slot for rowId is null, it should return null. + * + * To support struct type, implementations must implement {@link #getChild(int)} and make this + * vector a tree structure. The number of child vectors must be same as the number of fields of + * the struct type, and each child vector is responsible to store the data for its corresponding + * struct field. + */ + public final ColumnarRow getStruct(int rowId) { + if (isNullAt(rowId)) return null; + return new ColumnarRow(this, rowId); + } + + /** + * Returns the array type value for rowId. If the slot for rowId is null, it should return null. + * + * To support array type, implementations must construct an {@link ColumnarArray} and return it in + * this method. {@link ColumnarArray} requires a {@link ColumnVector} that stores the data of all + * the elements of all the arrays in this vector, and an offset and length which points to a range + * in that {@link ColumnVector}, and the range represents the array for rowId. Implementations + * are free to decide where to put the data vector and offsets and lengths. For example, we can + * use the first child vector as the data vector, and store offsets and lengths in 2 int arrays in + * this vector. + */ + public abstract ColumnarArray getArray(int rowId); + + /** + * Returns the map type value for rowId. If the slot for rowId is null, it should return null. + * + * In Spark, map type value is basically a key data array and a value data array. A key from the + * key array with a index and a value from the value array with the same index contribute to + * an entry of this map type value. + * + * To support map type, implementations must construct a {@link ColumnarMap} and return it in + * this method. 
{@link ColumnarMap} requires a {@link ColumnVector} that stores the data of all + * the keys of all the maps in this vector, and another {@link ColumnVector} that stores the data + * of all the values of all the maps in this vector, and a pair of offset and length which + * specify the range of the key/value array that belongs to the map type value at rowId. + */ + public abstract ColumnarMap getMap(int ordinal); + + /** + * Returns the decimal type value for rowId. If the slot for rowId is null, it should return null. + */ + public abstract Decimal getDecimal(int rowId, int precision, int scale); + + /** + * Returns the string type value for rowId. If the slot for rowId is null, it should return null. + * Note that the returned UTF8String may point to the data of this column vector, please copy it + * if you want to keep it after this column vector is freed. + */ + public abstract UTF8String getUTF8String(int rowId); + + /** + * Returns the binary type value for rowId. If the slot for rowId is null, it should return null. + */ + public abstract byte[] getBinary(int rowId); + + /** + * Returns the calendar interval type value for rowId. If the slot for rowId is null, it should + * return null. + * + * In Spark, calendar interval type value is basically an integer value representing the number of + * months in this interval, and a long value representing the number of microseconds in this + * interval. An interval type vector is the same as a struct type vector with 2 fields: `months` + * and `microseconds`. + * + * To support interval type, implementations must implement {@link #getChild(int)} and define 2 + * child vectors: the first child vector is an int type vector, containing all the month values of + * all the interval values in this vector. The second child vector is a long type vector, + * containing all the microsecond values of all the interval values in this vector. + */ + public final CalendarInterval getInterval(int rowId) { + if (isNullAt(rowId)) return null; + final int months = getChild(0).getInt(rowId); + final long microseconds = getChild(1).getLong(rowId); + return new CalendarInterval(months, microseconds); + } + + /** + * @return child [[ColumnVector]] at the given ordinal. + */ + protected abstract ColumnVector getChild(int ordinal); + + /** + * Data type for this column. + */ + protected DataType type; + + /** + * Sets up the data type of this column vector. + */ + protected ColumnVector(DataType type) { + this.type = type; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index cbc39d1d0aec..72a192d089b9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -14,18 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.util.ArrayData; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** - * Array abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Array abstraction in {@link ColumnVector}. */ +@InterfaceStability.Evolving public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from // data[offset] to data[offset + length). @@ -33,7 +33,7 @@ public final class ColumnarArray extends ArrayData { private final int offset; private final int length; - ColumnarArray(ColumnVector data, int offset, int length) { + public ColumnarArray(ColumnVector data, int offset, int length) { this.data = data; this.offset = offset; this.length = length; @@ -134,9 +134,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { - int month = data.getChildColumn(0).getInt(offset + ordinal); - long microseconds = data.getChildColumn(1).getLong(offset + ordinal); - return new CalendarInterval(month, microseconds); + return data.getInterval(offset + ordinal); } @Override @@ -150,8 +148,8 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java similarity index 60% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index a9d09aa67972..d206c1df42ab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -14,32 +14,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import java.util.*; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; /** - * This class is the in memory representation of rows as they are streamed through operators. It - * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that - * each operator allocates one of these objects, the storage footprint on the task is negligible. - * - * The layout is a columnar with values encoded in their native format. Each RowBatch contains - * a horizontal partitioning of the data, split into columns. - * - * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API. - * - * TODO: - * - There are many TODOs for the existing APIs. They should throw a not implemented exception. - * - Compaction: The batch and columns should be able to compact based on a selection vector. 
+ * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this + * batch so that Spark can access the data row by row. Instance of it is meant to be reused during + * the entire data loading process. */ +@InterfaceStability.Evolving public final class ColumnarBatch { - public static final int DEFAULT_BATCH_SIZE = 4 * 1024; - - private final StructType schema; - private final int capacity; private int numRows; private final ColumnVector[] columns; @@ -57,7 +46,7 @@ public void close() { } /** - * Returns an iterator over the rows in this batch. This skips rows that are filtered out. + * Returns an iterator over the rows in this batch. */ public Iterator rowIterator() { final int maxRows = numRows; @@ -87,22 +76,9 @@ public void remove() { } /** - * Resets the batch for writing. - */ - public void reset() { - for (int i = 0; i < numCols(); ++i) { - if (columns[i] instanceof WritableColumnVector) { - ((WritableColumnVector) columns[i]).reset(); - } - } - this.numRows = 0; - } - - /** - * Sets the number of rows that are valid. + * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { - assert(numRows <= this.capacity); this.numRows = numRows; } @@ -116,16 +92,6 @@ public void setNumRows(int numRows) { */ public int numRows() { return numRows; } - /** - * Returns the schema that makes up this batch. - */ - public StructType schema() { return schema; } - - /** - * Returns the max capacity (in number of rows) for this batch. - */ - public int capacity() { return capacity; } - /** * Returns the column at `ordinal`. */ @@ -140,10 +106,8 @@ public InternalRow getRow(int rowId) { return row; } - public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { - this.schema = schema; + public ColumnarBatch(ColumnVector[] columns) { this.columns = columns; - this.capacity = capacity; this.row = new MutableColumnarRow(columns); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java new file mode 100644 index 000000000000..35648e386c4f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.vectorized; + +import org.apache.spark.sql.catalyst.util.MapData; + +/** + * Map abstraction in {@link ColumnVector}. 
+ */ +public final class ColumnarMap extends MapData { + private final ColumnarArray keys; + private final ColumnarArray values; + private final int length; + + public ColumnarMap(ColumnVector keys, ColumnVector values, int offset, int length) { + this.length = length; + this.keys = new ColumnarArray(keys, offset, length); + this.values = new ColumnarArray(values, offset, length); + } + + @Override + public int numElements() { return length; } + + @Override + public ColumnarArray keyArray() { + return keys; + } + + @Override + public ColumnarArray valueArray() { + return values; + } + + @Override + public ColumnarMap copy() { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java similarity index 70% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 8bb33ed5b78c..f2f227959002 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -14,27 +14,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** - * Row abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Row abstraction in {@link ColumnVector}. */ +@InterfaceStability.Evolving public final class ColumnarRow extends InternalRow { // The data for this row. - // E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. + // E.g. the value of 3rd int field is `data.getChild(3).getInt(rowId)`. 
private final ColumnVector data; private final int rowId; private final int numFields; - ColumnarRow(ColumnVector data, int rowId) { + public ColumnarRow(ColumnVector data, int rowId) { assert (data.dataType() instanceof StructType); this.data = data; this.rowId = rowId; @@ -54,7 +54,7 @@ public InternalRow copy() { if (isNullAt(i)) { row.setNullAt(i); } else { - DataType dt = data.getChildColumn(i).dataType(); + DataType dt = data.getChild(i).dataType(); if (dt instanceof BooleanType) { row.setBoolean(i, getBoolean(i)); } else if (dt instanceof ByteType) { @@ -94,70 +94,62 @@ public boolean anyNull() { } @Override - public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); } + public boolean isNullAt(int ordinal) { return data.getChild(ordinal).isNullAt(rowId); } @Override - public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); } + public boolean getBoolean(int ordinal) { return data.getChild(ordinal).getBoolean(rowId); } @Override - public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); } + public byte getByte(int ordinal) { return data.getChild(ordinal).getByte(rowId); } @Override - public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); } + public short getShort(int ordinal) { return data.getChild(ordinal).getShort(rowId); } @Override - public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); } + public int getInt(int ordinal) { return data.getChild(ordinal).getInt(rowId); } @Override - public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); } + public long getLong(int ordinal) { return data.getChild(ordinal).getLong(rowId); } @Override - public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); } + public float getFloat(int ordinal) { return data.getChild(ordinal).getFloat(rowId); } @Override - public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); } + public double getDouble(int ordinal) { return data.getChild(ordinal).getDouble(rowId); } @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale); + return data.getChild(ordinal).getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getUTF8String(rowId); + return data.getChild(ordinal).getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getBinary(rowId); + return data.getChild(ordinal).getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId); - final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId); - return new CalendarInterval(months, microseconds); + return data.getChild(ordinal).getInterval(rowId); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getStruct(rowId); + return data.getChild(ordinal).getStruct(rowId); } @Override 
public ColumnarArray getArray(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getArray(rowId); + return data.getChild(ordinal).getArray(rowId); } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + return data.getChild(ordinal).getMap(rowId); } @Override diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 6cdfe2fae564..1b37905543b4 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,5 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.TextSocketSourceProvider -org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateStreamProvider +org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 92988680871a..ad0efbae8983 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1083,10 +1083,10 @@ class Column(val expr: Expression) extends Logging { * and null values return before non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing first. - * df.sort(df("age").asc_nulls_last) + * df.sort(df("age").asc_nulls_first) * * // Java - * df.sort(df.col("age").asc_nulls_last()); + * df.sort(df.col("age").asc_nulls_first()); * }}} * * @group expr_ops diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e8d683a578f3..53f44888ebaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,6 +21,8 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ +import com.fasterxml.jackson.databind.ObjectMapper + import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.JavaRDD @@ -34,7 +36,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -74,6 +76,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * infer the input schema automatically from data. 
By specifying the schema here, the underlying * data source can skip the schema inference step, and thus speed up data loading. * + * {{{ + * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv") + * }}} + * * @since 2.3.0 */ def schema(schemaString: String): DataFrameReader = { @@ -167,7 +173,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def load(path: String): DataFrame = { - option("path", path).load(Seq.empty: _*) // force invocation of `load(...varargs...)` + // force invocation of `load(...varargs...)` + option(DataSourceOptions.PATH_KEY, path).load(Seq.empty: _*) } /** @@ -185,46 +192,36 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() - val options = new DataSourceV2Options((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = sparkSession.sessionState.conf)).asJava) - - val reader = (ds, userSpecifiedSchema) match { - case (ds: ReadSupportWithSchema, Some(schema)) => - ds.createReader(schema, options) - - case (ds: ReadSupport, None) => - ds.createReader(options) - - case (ds: ReadSupportWithSchema, None) => - throw new AnalysisException(s"A schema needs to be specified when using $ds.") - - case (ds: ReadSupport, Some(schema)) => - val reader = ds.createReader(options) - if (reader.readSchema() != schema) { - throw new AnalysisException(s"$ds does not allow user-specified schemas.") - } - reader - - case _ => - throw new AnalysisException(s"$cls does not support data reading.") + val ds = cls.newInstance().asInstanceOf[DataSourceV2] + if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = ds, conf = sparkSession.sessionState.conf) + val pathsOption = { + val objectMapper = new ObjectMapper() + DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) + } + Dataset.ofRows(sparkSession, DataSourceV2Relation.create( + ds, extraOptions.toMap ++ sessionOptions + pathsOption, + userSpecifiedSchema = userSpecifiedSchema)) + } else { + loadV1Source(paths: _*) } - - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) } else { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) + loadV1Source(paths: _*) } } + private def loadV1Source(paths: String*) = { + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } + /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. @@ -354,12 +351,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. *
        - *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` - * field in an output schema.
      • + *
      • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * keep corrupt records, a user can set a string type field named + * `columnNameOfCorruptRecord` in a user-defined schema. If a schema does not have the + * field, it drops corrupt records during parsing. When inferring a schema, it implicitly + * adds a `columnNameOfCorruptRecord` field in an output schema. (A usage sketch follows + * this list.)
      • *
      • `DROPMALFORMED` : ignores whole corrupted records.
      • *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • *
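A minimal Scala sketch of how these three parse modes behave in practice (not part of the patch itself): it assumes a spark-shell session, so `spark` is already defined, and a hypothetical input file `people.json` containing a few malformed lines.

import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Schema with an explicit string field to hold malformed records.
val jsonSchema = new StructType()
  .add("name", StringType)
  .add("age", IntegerType)
  .add("_corrupt_record", StringType)

// PERMISSIVE (the default): malformed lines end up in `_corrupt_record`,
// and the remaining fields of that row are set to null.
val permissive = spark.read
  .schema(jsonSchema)
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_corrupt_record")
  .json("people.json")

// DROPMALFORMED: malformed lines are silently dropped.
val dropped = spark.read.schema(jsonSchema).option("mode", "DROPMALFORMED").json("people.json")

// FAILFAST: the first malformed line aborts the read with an exception.
val strict = spark.read.schema(jsonSchema).option("mode", "FAILFAST").json("people.json")

With PERMISSIVE, the malformed line is preserved verbatim in `_corrupt_record` while the typed columns of that row are null, which makes auditing bad input straightforward.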
      @@ -375,6 +372,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • *
    • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
    • + *
    • `encoding` (by default it is not set): allows one to forcibly set one of the standard basic + * or extended encodings for the JSON files, for example UTF-16BE or UTF-32LE. If the encoding + * is not specified and `multiLine` is set to `true`, it will be detected automatically.
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • + *
    • `samplingRatio` (default is 1.0): defines the fraction of input JSON objects used + * for schema inference.
    • * * * @since 2.0.0 @@ -535,6 +539,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `header` (default `false`): uses the first line as names of columns.
    • *
    • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
    • + *
    • `samplingRatio` (default is 1.0): defines the fraction of rows used for schema inference.
    • *
    • `ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading * whitespaces from values being read should be skipped.
    • *
    • `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing @@ -559,12 +564,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. It supports the following case-insensitive modes. *
        - *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + *
      • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep * corrupt records, a user can set a string type field named `columnNameOfCorruptRecord` * in a user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When a length of parsed CSV tokens is shorter than an expected length - * of a schema, it sets `null` for extra fields. + * during parsing. A record with fewer or more tokens than the schema is not a corrupted record + * for CSV. When it meets a record having fewer tokens than the length of the schema, it sets + * `null` for the missing fields. When the record has more tokens than the length of the schema, + * it drops the extra tokens. (A usage sketch follows this list.)
      • + * during parsing. A record with less/more tokens than schema is not a corrupted record to + * CSV. When it meets a record having fewer tokens than the length of the schema, sets + * `null` to extra fields. When the record has more tokens than the length of the schema, + * it drops extra tokens. *
      • `DROPMALFORMED` : ignores whole corrupted records.
      • *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • *
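A rough spark-shell sketch of the CSV token-count rules just described (again, not part of the patch): the file name, its contents, and the column names are made up for illustration; `spark` is assumed to come from the shell.

import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Hypothetical items.csv:
//   a,1,x
//   b,2          <- fewer tokens than the schema: `tag` becomes null
//   c,3,y,zzz    <- more tokens than the schema: "zzz" is dropped
val csvSchema = new StructType()
  .add("name", StringType)
  .add("count", IntegerType)
  .add("tag", StringType)

val items = spark.read
  .schema(csvSchema)
  .option("mode", "PERMISSIVE")
  .csv("items.csv")

items.show()

Under PERMISSIVE, none of these rows is treated as corrupt; only records whose tokens cannot be converted to the declared types go through the `columnNameOfCorruptRecord` path.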
      @@ -654,14 +661,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * Loads text files and returns a `DataFrame` whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * - * You can set the following text-specific option(s) for reading text files: - *
        - *
      • `wholetext` ( default `false`): If true, read a file as a single row and not split by "\n". - *
      • - *
      - * By default, each line in the text files is a new row in the resulting DataFrame. - * - * Usage example: + * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: * spark.read.text("/path/to/spark/README.md") @@ -670,6 +670,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * spark.read().text("/path/to/spark/README.md") * }}} * + * You can set the following text-specific option(s) for reading text files: + *
        + *
      • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
      • + *
      • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
      • + *
      + * * @param paths input paths * @since 1.6.0 */ @@ -693,11 +701,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * If the directory structure of the text files contains partitioning information, those are * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. * - * You can set the following textFile-specific option(s) for reading text files: - *
        - *
      • `wholetext` ( default `false`): If true, read a file as a single row and not split by "\n". - *
      • - *
      * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: @@ -707,6 +710,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * spark.read().textFile("/path/to/spark/README.md") * }}} * + * You can set the following textFile-specific option(s) for reading text files: + *
        + *
      • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
      • + *
      • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
      • + *
      + * * @param paths input path * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3304f368e105..e183fa6f9542 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -174,7 +174,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * predicates on the partitioned columns. In order for partitioning to work well, the number * of distinct values in each column should typically be less than tens of thousands. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 1.4.0 */ @@ -188,7 +189,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * Buckets the output by the given columns. If specified, the output is laid out on the file * system similar to Hive's bucketing scheme. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 2.0 */ @@ -202,7 +204,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * Sorts the output in each bucket by the given columns. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 2.0 */ @@ -240,7 +243,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val ds = cls.newInstance() ds match { case ws: WriteSupport => - val options = new DataSourceV2Options((extraOptions ++ + val options = new DataSourceOptions((extraOptions ++ DataSourceV2Utils.extractSessionConfigs( ds = ds.asInstanceOf[DataSourceV2], conf = df.sparkSession.sessionState.conf)).asJava) @@ -255,17 +258,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - case _ => throw new AnalysisException(s"$cls does not support data writing.") + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving + // as though it's a V1 source. + case _ => saveToV1Source() } } else { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) - } + saveToV1Source() + } + } + + private def saveToV1Source(): Unit = { + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) } } @@ -304,7 +314,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { if (partitioningColumns.isDefined) { throw new AnalysisException( "insertInto() can't be used together with partitionBy(). " + - "Partition columns have already be defined for the table. " + + "Partition columns have already been defined for the table. " + "It is not necessary to use partitionBy()." 
) } @@ -508,6 +518,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + *
    • `encoding` (by default it is not set): specifies the encoding (charset) of saved JSON + * files. If it is not set, the UTF-8 charset will be used.
    • + *
    • `lineSep` (default `\n`): defines the line separator that should be used for writing.
    • * * * @since 1.4.0 @@ -577,6 +590,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shortened names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
    • + *
    • `lineSep` (default `\n`): defines the line separator that should be used for writing.
    • * * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 77e571272920..cd4def71e6f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -62,7 +62,11 @@ import org.apache.spark.util.Utils private[sql] object Dataset { def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { - new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) + val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) + // Eagerly bind the encoder so we verify that the encoder matches the underlying + // schema. The user will get an error if this is not the case. + dataset.deserializer + dataset } def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { @@ -204,7 +208,7 @@ class Dataset[T] private[sql]( // The deserializer expression which can be used to build a projection and turn rows to objects // of type T, after collecting rows to the driver side. - private val deserializer = + private lazy val deserializer = exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer private implicit def classTag = exprEnc.clsTag @@ -237,13 +241,20 @@ class Dataset[T] private[sql]( private[sql] def showString( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0).min(Int.MaxValue - 1) - val takeResult = toDF().take(numRows + 1) + val newDf = toDF() + val castCols = newDf.logicalPlan.output.map { col => + // Since binary types in top-level schema fields have a specific format to print, + // so we do not cast them to strings here. + if (col.dataType == BinaryType) { + Column(col) + } else { + Column(col).cast(StringType) + } + } + val takeResult = newDf.select(castCols: _*).take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) - lazy val timeZone = - DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) - // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the // first `truncate-3` and "..." @@ -252,12 +263,6 @@ class Dataset[T] private[sql]( val str = cell match { case null => "null" case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") - case array: Array[_] => array.mkString("[", ", ", "]") - case seq: Seq[_] => seq.mkString("[", ", ", "]") - case d: Date => - DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) - case ts: Timestamp => - DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(ts), timeZone) case _ => cell.toString } if (truncate > 0 && str.length > truncate) { @@ -1193,7 +1198,7 @@ class Dataset[T] private[sql]( def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*) /** - * Selects column based on the column name and return it as a [[Column]]. + * Selects column based on the column name and returns it as a [[Column]]. * * @note The column name can also reference to a nested column like `a.b`. * @@ -1219,7 +1224,7 @@ class Dataset[T] private[sql]( } /** - * Selects column based on the column name and return it as a [[Column]]. + * Selects column based on the column name and returns it as a [[Column]]. * * @note The column name can also reference to a nested column like `a.b`. 
* @@ -1239,7 +1244,7 @@ class Dataset[T] private[sql]( } /** - * Selects column based on the column name specified as a regex and return it as [[Column]]. + * Selects column based on the column name specified as a regex and returns it as [[Column]]. * @group untypedrel * @since 2.3.0 */ @@ -1902,7 +1907,7 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. - * This is equivalent to `EXCEPT` in SQL. + * This is equivalent to `EXCEPT DISTINCT` in SQL. * * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. @@ -2145,6 +2150,9 @@ class Dataset[T] private[sql]( * Returns a new Dataset by adding a column or replacing the existing column that has * the same name. * + * `column`'s expression must only refer to attributes supplied by this Dataset. It is an + * error to add a column that refers to some other Dataset. + * * @group untypedrel * @since 2.0.0 */ @@ -2728,7 +2736,7 @@ class Dataset[T] private[sql]( } /** - * Return an iterator that contains all rows in this Dataset. + * Returns an iterator that contains all rows in this Dataset. * * The iterator will consume as much memory as the largest partition in this Dataset. * @@ -2817,6 +2825,7 @@ class Dataset[T] private[sql]( * * At least one partition-by expression must be specified. * When no explicit sort order is specified, "ascending nulls first" is assumed. + * Note, the rows are not sorted in each partition of the resulting Dataset. * * @group typedrel * @since 2.3.0 @@ -2840,6 +2849,7 @@ class Dataset[T] private[sql]( * * At least one partition-by expression must be specified. * When no explicit sort order is specified, "ascending nulls first" is assumed. + * Note, the rows are not sorted in each partition of the resulting Dataset. * * @group typedrel * @since 2.3.0 @@ -2923,7 +2933,7 @@ class Dataset[T] private[sql]( */ def storageLevel: StorageLevel = { sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData => - cachedData.cachedRepresentation.storageLevel + cachedData.cachedRepresentation.cacheBuilder.storageLevel }.getOrElse(StorageLevel.NONE) } @@ -3179,10 +3189,10 @@ class Dataset[T] private[sql]( private[sql] def collectToPython(): Int = { EvaluatePython.registerPicklers() - withNewExecutionId { + withAction("collectToPython", queryExecution) { plan => val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) - val iter = new SerDeUtil.AutoBatchedPickler( - queryExecution.executedPlan.executeCollect().iterator.map(toJava)) + val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( + plan.executeCollect().iterator.map(toJava)) PythonRDD.serveIterator(iter, "serve-DataFrame") } } @@ -3191,8 +3201,9 @@ class Dataset[T] private[sql]( * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. 
*/ private[sql] def collectAsArrowToPython(): Int = { - withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) + withAction("collectAsArrowToPython", queryExecution) { plan => + val iter: Iterator[Array[Byte]] = + toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) PythonRDD.serveIterator(iter, "serve-Arrow") } } @@ -3301,14 +3312,19 @@ class Dataset[T] private[sql]( } /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload: RDD[ArrowPayload] = { + private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone - queryExecution.toRdd.mapPartitionsInternal { iter => + plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() ArrowConverters.toPayloadIterator( iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } + + // This is only used in tests, for now. + private[sql] def toArrowPayload: RDD[ArrowPayload] = { + toArrowPayload(queryExecution.executedPlan) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index a009c00b0abc..6c2be3610ae3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -74,7 +73,7 @@ class RelationalGroupedDataset protected[sql]( case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) } } @@ -450,8 +449,8 @@ class RelationalGroupedDataset protected[sql]( * workers. 
*/ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - "Must pass a group map udf") + require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + "Must pass a grouped map udf") require(expr.dataType.isInstanceOf[StructType], "The returnType of the udf must be a StructType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 272eb844226d..c502e583a55c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallSite, Utils} /** @@ -81,6 +81,9 @@ class SparkSession private( @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => + // The call site where this SparkSession was constructed. + private val creationSite: CallSite = Utils.getCallSite() + private[sql] def this(sc: SparkContext) { this(sc, None, None, new SparkSessionExtensions) } @@ -742,7 +745,10 @@ class SparkSession private( private[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) + val rowRdd = rdd.mapPartitions { iter => + val fromJava = python.EvaluatePython.makeFromJava(schema) + iter.map(r => fromJava(r).asInstanceOf[InternalRow]) + } internalCreateDataFrame(rowRdd, schema) } @@ -760,7 +766,7 @@ class SparkSession private( @InterfaceStability.Stable -object SparkSession { +object SparkSession extends Logging { /** * Builder for [[SparkSession]]. @@ -948,7 +954,8 @@ object SparkSession { session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) } - defaultSession.set(session) + setDefaultSession(session) + setActiveSession(session) // Register a successfully instantiated context to the singleton. This should be at the // end of the class definition so that the singleton is updated only if there is no @@ -1024,6 +1031,17 @@ object SparkSession { */ def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. + * + * @since 2.4.0 + */ + def active: SparkSession = { + getActiveSession.getOrElse(getDefaultSession.getOrElse( + throw new IllegalStateException("No active or default Spark session found"))) + } + //////////////////////////////////////////////////////////////////////////////////////// // Private methods from now on //////////////////////////////////////////////////////////////////////////////////////// @@ -1075,4 +1093,20 @@ object SparkSession { } } + private[spark] def cleanupAnyExistingSession(): Unit = { + val session = getActiveSession.orElse(getDefaultSession) + if (session.isDefined) { + logWarning( + s"""An existing Spark session exists as the active or default session. + |This probably means another suite leaked it. Attempting to stop it before continuing. 
+ |This existing Spark session was created at: + | + |${session.get.creationSite.longForm} + | + """.stripMargin) + session.get.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index dc2468a721e4..f94baef39dfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.lang.reflect.{ParameterizedType, Type} +import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -110,29 +110,29 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends /* register 0-22 were generated by this script - (0 to 22).map { x => + (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) println(s""" - /** - * Registers a deterministic Scala closure of ${x} arguments as user-defined function (UDF). - * @tparam RT return type of UDF. - * @since 1.3.0 - */ - def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - def builder(e: Seq[Expression]) = if (e.length == $x) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) - } else { - throw new AnalysisException("Invalid number of arguments for function " + name + - ". Expected: $x; Found: " + e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). + | * @tparam RT return type of UDF. + | * @since 1.3.0 + | */ + |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | def builder(e: Seq[Expression]) = if (e.length == $x) { + | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + | } else { + | throw new AnalysisException("Invalid number of arguments for function " + name + + | ". Expected: $x; Found: " + e.length) + | } + | functionRegistry.createOrReplaceTempFunction(name, builder) + | val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) } (0 to 22).foreach { i => @@ -144,7 +144,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val funcCall = if (i == 0) "() => func" else "func" println(s""" |/** - | * Register a user-defined function with ${i} arguments. + | * Register a deterministic Java UDF$i instance as user-defined function (UDF). 
| * @since $version | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { @@ -689,7 +689,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 0 arguments. + * Register a deterministic Java UDF0 instance as user-defined function (UDF). * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { @@ -704,7 +704,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 1 arguments. + * Register a deterministic Java UDF1 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { @@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 2 arguments. + * Register a deterministic Java UDF2 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { @@ -734,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 3 arguments. + * Register a deterministic Java UDF3 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { @@ -749,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 4 arguments. + * Register a deterministic Java UDF4 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { @@ -764,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 5 arguments. + * Register a deterministic Java UDF5 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { @@ -779,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 6 arguments. + * Register a deterministic Java UDF6 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -794,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 7 arguments. + * Register a deterministic Java UDF7 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -809,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 8 arguments. + * Register a deterministic Java UDF8 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -824,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 9 arguments. + * Register a deterministic Java UDF9 instance as user-defined function (UDF). 
* @since 1.3.0 */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -839,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 10 arguments. + * Register a deterministic Java UDF10 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -854,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 11 arguments. + * Register a deterministic Java UDF11 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -869,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 12 arguments. + * Register a deterministic Java UDF12 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -884,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 13 arguments. + * Register a deterministic Java UDF13 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -899,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 14 arguments. + * Register a deterministic Java UDF14 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -914,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 15 arguments. + * Register a deterministic Java UDF15 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -929,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 16 arguments. + * Register a deterministic Java UDF16 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -944,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 17 arguments. + * Register a deterministic Java UDF17 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -959,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 18 arguments. + * Register a deterministic Java UDF18 instance as user-defined function (UDF). 
* @since 1.3.0 */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -974,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 19 arguments. + * Register a deterministic Java UDF19 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -989,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 20 arguments. + * Register a deterministic Java UDF20 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1004,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 21 arguments. + * Register a deterministic Java UDF21 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1019,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 22 arguments. + * Register a deterministic Java UDF22 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index b05fe49a6ac3..93bf91e56f1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel @@ -71,7 +71,7 @@ class CacheManager extends Logging { /** Clears all cached tables. 
*/ def clearCache(): Unit = writeLock { - cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.asScala.foreach(_.cachedRepresentation.cacheBuilder.clearCache()) cachedData.clear() } @@ -99,7 +99,7 @@ class CacheManager extends Logging { sparkSession.sessionState.conf.columnBatchSize, storageLevel, sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName, - planToCache.stats) + planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) } } @@ -119,7 +119,7 @@ class CacheManager extends Logging { while (it.hasNext) { val cd = it.next() if (cd.plan.find(_.sameResult(plan)).isDefined) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cd.cachedRepresentation.cacheBuilder.clearCache(blocking) it.remove() } } @@ -138,17 +138,15 @@ class CacheManager extends Logging { while (it.hasNext) { val cd = it.next() if (condition(cd.plan)) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist() + cd.cachedRepresentation.cacheBuilder.clearCache() // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() + val plan = spark.sessionState.executePlan(cd.plan).executedPlan val newCache = InMemoryRelation( - useCompression = cd.cachedRepresentation.useCompression, - batchSize = cd.cachedRepresentation.batchSize, - storageLevel = cd.cachedRepresentation.storageLevel, - child = spark.sessionState.executePlan(cd.plan).executedPlan, - tableName = cd.cachedRepresentation.tableName, - statsOfPlanToCache = cd.plan.stats) + cacheBuilder = cd.cachedRepresentation + .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null), + logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } } @@ -169,6 +167,13 @@ class CacheManager extends Logging { /** Replaces segments of the given logical plan with cached versions where possible. */ def useCachedData(plan: LogicalPlan): LogicalPlan = { val newPlan = plan transformDown { + // Do not lookup the cache by hint node. Hint node is special, we should ignore it when + // canonicalizing plans, so that plans which are same except hint can hit the same cache. + // However, we also want to keep the hint info after cache lookup. Here we skip the hint + // node, so that the returned caching plan won't replace the hint node and drop the hint info + // from the original plan. 
+ case hint: ResolvedHint => hint + case currentFragment => lookupCachedData(currentFragment) .map(_.cachedRepresentation.withOutput(currentFragment.output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 782cec5e292b..fc3dbc1c5591 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -17,21 +17,24 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** - * Helper trait for abstracting scan functionality using - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es. + * Helper trait for abstracting scan functionality using [[ColumnarBatch]]es. */ private[sql] trait ColumnarBatchScan extends CodegenSupport { def vectorTypes: Option[Seq[String]] = None + protected def supportsBatch: Boolean = true + + protected def needsUnsafeRowConversion: Boolean = true + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) @@ -46,20 +49,24 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { ordinal: String, dataType: DataType, nullable: Boolean): ExprCode = { - val javaType = ctx.javaType(dataType) - val value = ctx.getValue(columnVar, dataType, ordinal) - val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) + val isNullVar = if (nullable) { + JavaCode.isNullVariable(ctx.freshName("isNull")) + } else { + FalseLiteral + } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { s""" boolean $isNullVar = $columnVar.isNullAt($ordinal); - $javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value); + $javaType $valueVar = $isNullVar ? 
${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, valueVar) + ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } /** @@ -71,16 +78,24 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // PhysicalRDD always just has one input val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") + if (supportsBatch) { + produceBatches(ctx, input) + } else { + produceRows(ctx, input) + } + } + private def produceBatches(ctx: CodegenContext, input: String): String = { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") - val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0 + val scanTimeTotalNs = + ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0 val columnarBatchClz = classOf[ColumnarBatch].getName val batch = ctx.addMutableState(columnarBatchClz, "batch") - val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0 + val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0 val columnVectorClzs = vectorTypes.getOrElse( Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { @@ -137,4 +152,25 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { """.stripMargin } + private def produceRows(ctx: CodegenContext, input: String): String = { + val numOutputRows = metricTerm(ctx, "numOutputRows") + val row = ctx.freshName("row") + + ctx.INPUT_ROW = row + ctx.currentVars = null + // Always provide `outputVars`, so that the framework can help us build unsafe row if the input + // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. 
+ val outputVars = output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + val inputRow = if (needsUnsafeRowConversion) null else row + s""" + |while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutputRows.add(1); + | ${consume(ctx, outputVars, inputRow).trim} + | if (shouldStop()) return; + |} + """.stripMargin + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index d1ff82c7c06b..08ff33afbba3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -90,16 +90,15 @@ case class RowDataSourceScanExec( Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { - val unsafeRow = rdd.mapPartitionsWithIndexInternal { (index, iter) => + val numOutputRows = longMetric("numOutputRows") + + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) proj.initialize(index) - iter.map(proj) - } - - val numOutputRows = longMetric("numOutputRows") - unsafeRow.map { r => - numOutputRows += 1 - r + iter.map( r => { + numOutputRows += 1 + proj(r) + }) } } @@ -164,13 +163,15 @@ case class FileSourceScanExec( override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { - val supportsBatch: Boolean = relation.fileFormat.supportBatch( + override val supportsBatch: Boolean = relation.fileFormat.supportBatch( relation.sparkSession, StructType.fromAttributes(output)) - val needsUnsafeRowConversion: Boolean = if (relation.fileFormat.isInstanceOf[ParquetSource]) { - SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled - } else { - false + override val needsUnsafeRowConversion: Boolean = { + if (relation.fileFormat.isInstanceOf[ParquetSource]) { + SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled + } else { + false + } } override def vectorTypes: Option[Seq[String]] = @@ -322,57 +323,30 @@ case class FileSourceScanExec( // in the case of fallback, this batched scan should never fail because of: // 1) only primitive types are supported // 2) the number of columns should be smaller than spark.sql.codegen.maxFields - WholeStageCodegenExec(this).execute() + WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { - val unsafeRows = { - val scan = inputRDD - if (needsUnsafeRowConversion) { - scan.mapPartitionsWithIndexInternal { (index, iter) => - val proj = UnsafeProjection.create(schema) - proj.initialize(index) - iter.map(proj) - } - } else { - scan - } - } val numOutputRows = longMetric("numOutputRows") - unsafeRows.map { r => - numOutputRows += 1 - r + + if (needsUnsafeRowConversion) { + inputRDD.mapPartitionsWithIndexInternal { (index, iter) => + val proj = UnsafeProjection.create(schema) + proj.initialize(index) + iter.map( r => { + numOutputRows += 1 + proj(r) + }) + } + } else { + inputRDD.map { r => + numOutputRows += 1 + r + } } } } override val nodeNamePrefix: String = "File" - override protected def doProduce(ctx: CodegenContext): String = { - if (supportsBatch) { - return super.doProduce(ctx) - } - val numOutputRows = metricTerm(ctx, "numOutputRows") - // PhysicalRDD always just has one input - val input = 
ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") - val row = ctx.freshName("row") - - ctx.INPUT_ROW = row - ctx.currentVars = null - // Always provide `outputVars`, so that the framework can help us build unsafe row if the input - // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. - val outputVars = output.zipWithIndex.map{ case (a, i) => - BoundReference(i, a.dataType, a.nullable).genCode(ctx) - } - val inputRow = if (needsUnsafeRowConversion) null else row - s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutputRows.add(1); - | ${consume(ctx, outputVars, inputRow).trim} - | if (shouldStop()) return; - |} - """.stripMargin - } - /** * Create an RDD for bucketed reads. * The non-bucketed variant of this function is [[createNonBucketedReadRDD]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index f3555508185f..be50a1571a2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -125,7 +125,7 @@ case class LogicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], outputPartitioning: Partitioning = UnknownPartitioning(0), - outputOrdering: Seq[SortOrder] = Nil, + override val outputOrdering: Seq[SortOrder] = Nil, override val isStreaming: Boolean = false)(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index a7bd5ebf93ec..e4812f3d338f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -154,9 +154,13 @@ case class ExpandExec( val value = ctx.freshName("value") val code = s""" |boolean $isNull = true; - |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; + |${CodeGenerator.javaType(firstExpr.dataType)} $value = + | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, isNull, value) + ExprCode( + code, + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, firstExpr.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 0c2c4a1a9100..f40c50df74cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import 
org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.types._ /** * For lazy computing, be sure the generator.terminate() called in the very last @@ -170,9 +170,11 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", s"$index == -1", index)) + Seq(ExprCode( + JavaCode.isNullExpression(s"$index == -1"), + JavaCode.variable(index, IntegerType))) } else { - Seq(ExprCode("", "false", index)) + Seq(ExprCode(FalseLiteral, JavaCode.variable(index, IntegerType))) } } else { Seq.empty @@ -305,19 +307,19 @@ case class GenerateExec( nullable: Boolean, initialChecks: Seq[String]): ExprCode = { val value = ctx.freshName(name) - val javaType = ctx.javaType(dt) - val getter = ctx.getValue(source, dt, index) + val javaType = CodeGenerator.javaType(dt) + val getter = CodeGenerator.getValue(source, dt, index) val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)") if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = s""" |boolean $isNull = ${checks.mkString(" || ")}; - |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; + |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", "false", value) + ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 514ad7018d8c..448eb703eacd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -25,6 +25,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics /** * Physical plan node for scanning data from a local collection. + * + * `Seq` may not be serializable and ideally we should not send `rows` and `unsafeRows` + * to the executors. Thus marking them as transient. 
*/ case class LocalTableScanExec( output: Seq[Attribute],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 18f6f697bc85..3ca03ab2939a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql.execution +import java.util.Locale + +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} @@ -46,9 +50,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic } plan.transform { - case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation)) => + case a @ Aggregate(_, aggExprs, child @ PhysicalOperation( + projectList, filters, PartitionedRelation(partAttrs, rel))) => // We only apply this optimization when only partitioned attributes are scanned. - if (a.references.subsetOf(partAttrs)) { + if (AttributeSet((projectList ++ filters).flatMap(_.references)).subsetOf(partAttrs)) { + // The project list and filters all refer only to partition attributes, which means the + // Aggregate operator can also refer only to partition attributes, and the filters are + // all partition filters. This is a metadata-only query that we can optimize. val aggFunctions = aggExprs.flatMap(_.collect { case agg: AggregateExpression => agg }) @@ -64,7 +72,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic }) } if (isAllDistinctAgg) { - a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, relation))) + a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, rel, filters))) } else { a } @@ -80,8 +88,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic private def getPartitionAttrs( partitionColumnNames: Seq[String], relation: LogicalPlan): Seq[Attribute] = { - val partColumns = partitionColumnNames.map(_.toLowerCase).toSet - relation.output.filter(a => partColumns.contains(a.name.toLowerCase)) + val attrMap = relation.output.map(a => a.name.toLowerCase(Locale.ROOT) -> a).toMap + partitionColumnNames.map { colName => + attrMap.getOrElse(colName.toLowerCase(Locale.ROOT), + throw new AnalysisException(s"Unable to find the column `$colName` " + + s"given [${relation.output.map(_.name).mkString(", ")}]") + ) + } } /** @@ -90,13 +103,23 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic */ private def replaceTableScanWithPartitionMetadata( child: LogicalPlan, - relation: LogicalPlan): LogicalPlan = { + relation: LogicalPlan, + partFilters: Seq[Expression]): LogicalPlan = { + // This logic comes from PruneFileSourcePartitions. It ensures that the filter names match the + // relation's schema.
PartitionedRelation ensures that the filters only reference partition cols + val normalizedFilters = partFilters.map { e => + e transform { + case a: AttributeReference => + a.withName(relation.output.find(_.semanticEquals(a)).get.name) + } + } + child transform { case plan if plan eq relation => relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(Nil, Nil) + val partitionData = fsRelation.location.listFiles(normalizedFilters, Nil) LocalRelation(partAttrs, partitionData.map(_.values), isStreaming) case relation: HiveTableRelation => @@ -105,7 +128,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic CaseInsensitiveMap(relation.tableMeta.storage.properties) val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(SQLConf.get.sessionLocalTimeZone) - val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => + val partitions = if (partFilters.nonEmpty) { + catalog.listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters) + } else { + catalog.listPartitions(relation.tableMeta.identifier) + } + + val partitionData = partitions.map { p => InternalRow.fromSeq(partAttrs.map { attr => Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) @@ -122,34 +151,23 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic /** * A pattern that finds the partitioned table relation node inside the given plan, and returns a * pair of the partition attributes and the table relation node. - * - * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with - * deterministic expressions, and returns result after reaching the partitioned table relation - * node. 
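A hedged example of the query shape this rule now handles (the table `t`, its partition column `ds`, and the session used here are illustrative assumptions, not part of the patch): both the filter and the distinct aggregate reference only the partition column, so the answer can be computed from partition metadata, and the pushed-down filter now prunes the listed partitions.

    spark.conf.set("spark.sql.optimizer.metadataOnly", "true")
    // Only the partition column `ds` is referenced, so the rewritten rule can answer this
    // from listFiles / listPartitionsByFilter without scanning any data files.
    spark.sql("SELECT DISTINCT ds FROM t WHERE ds >= '2018-01-01'").explain()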
*/ - object PartitionedRelation { - - def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = plan match { - case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) - if fsRelation.partitionSchema.nonEmpty => - val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - Some((AttributeSet(partAttrs), l)) + object PartitionedRelation extends PredicateHelper { - case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => - val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) - Some((AttributeSet(partAttrs), relation)) + def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = { + plan match { + case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) + if fsRelation.partitionSchema.nonEmpty => + val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)) + Some((partAttrs, l)) - case p @ Project(projectList, child) if projectList.forall(_.deterministic) => - unapply(child).flatMap { case (partAttrs, relation) => - if (p.references.subsetOf(partAttrs)) Some((p.outputSet, relation)) else None - } - - case f @ Filter(condition, child) if condition.deterministic => - unapply(child).flatMap { case (partAttrs, relation) => - if (f.references.subsetOf(partAttrs)) Some((partAttrs, relation)) else None - } + case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => + val partAttrs = AttributeSet( + getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)) + Some((partAttrs, relation)) - case _ => None + case _ => None + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 8bfe3eff0c3b..15379a0663f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -44,19 +44,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner - def assertAnalyzed(): Unit = { - // Analyzer is invoked outside the try block to avoid calling it again from within the - // catch block below. 
- analyzed - try { - sparkSession.sessionState.analyzer.checkAnalysis(analyzed) - } catch { - case e: AnalysisException => - val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) - ae.setStackTrace(e.getStackTrace) - throw ae - } - } + def assertAnalyzed(): Unit = analyzed def assertSupported(): Unit = { if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { @@ -66,7 +54,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { lazy val analyzed: LogicalPlan = { SparkSession.setActiveSession(sparkSession) - sparkSession.sessionState.analyzer.execute(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical) } lazy val withCachedData: LogicalPlan = { @@ -167,6 +155,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { case (null, _) => "null" case (s: String, StringType) => "\"" + s + "\"" case (decimal, DecimalType()) => decimal.toString + case (interval, CalendarIntervalType) => interval.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } @@ -190,6 +179,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)) case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) + case (interval, CalendarIntervalType) => interval.toString case (other, tpe) if primitiveTypes.contains(tpe) => other.toString } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala index 16806c620635..cffd97baea6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala @@ -17,4 +17,5 @@ package org.apache.spark.sql.execution -class QueryExecutionException(message: String) extends Exception(message) +class QueryExecutionException(message: String, cause: Throwable = null) + extends Exception(message, cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index e991da7df0bd..2c5102b1e5ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -88,7 +88,7 @@ object SQLExecution { /** * Wrap an action with a known executionId. When running a different action in a different * thread from the original one, this method can be used to connect the Spark jobs in this action - * with the known executionId, e.g., `BroadcastHashJoin.broadcastFuture`. + * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. 
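A small illustration of the widened QueryExecutionException constructor above (the message and cause below are invented): the new `cause` parameter defaults to null, so existing single-argument call sites keep compiling while new call sites can preserve the underlying error.

    // Hypothetical internal call site: wrap a failure while keeping it as the cause.
    val cause = new RuntimeException("root cause from the underlying job")
    throw new QueryExecutionException("failed to execute the query stage", cause)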
*/ def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ef1bb1c2a446..0dc16ba5ce28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics @@ -84,7 +84,7 @@ case class SortExec( } val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - val sorter = new UnsafeExternalRowSorter( + val sorter = UnsafeExternalRowSorter.create( schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort) if (testSpillFrequency > 0) { @@ -133,7 +133,8 @@ case class SortExec( override def needStopCheck: Boolean = false override protected def doProduce(ctx: CodegenContext): String = { - val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") + val needToSort = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 787c1cfbfb3d..398758a3331b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -94,7 +94,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! - /** Specifies any partition requirements on the input data for this operator. */ + /** + * Specifies the data distribution requirements of all the children for this operator. By default + * it's [[UnspecifiedDistribution]] for each child, which means each child can have any + * distribution. + * + * If an operator overrides this method and specifies distribution requirements (excluding + * [[UnspecifiedDistribution]] and [[BroadcastDistribution]]) for more than one child, Spark + * guarantees that the outputs of these children will have the same number of partitions, so that + * the operator can safely zip partitions of these children's result RDDs. Some operators can + * leverage this guarantee to satisfy useful requirements, e.g., a non-broadcast join can specify + * HashClusteredDistribution(a,b) for its left child and HashClusteredDistribution(c,d) for its + * right child; it is then guaranteed that the left and right children are co-partitioned by + * a,b/c,d, which means tuples with the same values are in partitions with the same index, e.g., + * (a=1,b=2) and (c=1,d=2) are both in the second partition of the left and right children.
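To make the co-partitioning guarantee above concrete, here is a sketch in the shape used by Spark's shuffled join operators (with `leftKeys` and `rightKeys` standing in for the operator's actual join key expressions) of how a binary operator requests co-partitioned children:

    // Both children must be hash-clustered on their respective keys; Spark then guarantees the
    // two children have the same number of partitions, so partition i of the left child can be
    // safely zipped with partition i of the right child.
    override def requiredChildDistribution: Seq[Distribution] =
      HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil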
+ */ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) @@ -337,8 +351,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (buf.isEmpty) { numPartsToTry = partsScanned * limitScaleUpFactor } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1) + val left = n - buf.size + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 29b584b55972..4828fa60a7b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -327,7 +327,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } /** - * Create a [[DescribeTableCommand]] logical plan. + * Create a [[DescribeColumnCommand]] or [[DescribeTableCommand]] logical commands. */ override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { val isExtended = ctx.EXTENDED != null || ctx.FORMATTED != null @@ -383,16 +383,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name * USING table_provider - * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, col_name, ...)] - * [CLUSTERED BY (col_name, col_name, ...) - * [SORTED BY (col_name [ASC|DESC], ...)] - * INTO num_buckets BUCKETS - * ] - * [LOCATION path] - * [COMMENT table_comment] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { @@ -400,6 +403,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (external) { operationNotAllowed("CREATE EXTERNAL TABLE ... 
USING", ctx) } + + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText val schema = Option(ctx.colTypeList()).map(createSchema) @@ -408,9 +419,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { .map(visitIdentifierList(_).toArray) .getOrElse(Array.empty[String]) val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) val storage = DataSource.buildStorageFormatFromOptions(options) if (location.isDefined && storage.locationUri.isDefined) { @@ -1087,13 +1098,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name * [(col1[:] data_type [COMMENT col_comment], ...)] - * [COMMENT table_comment] - * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] - * [ROW FORMAT row_format] - * [STORED AS file_format] - * [LOCATION path] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [AS select_statement]; + * + * create_table_clauses (order insensitive): + * [COMMENT table_comment] + * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { @@ -1104,15 +1118,23 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { "CREATE TEMPORARY TABLE is not supported yet. " + "Please use CREATE TEMPORARY VIEW as an alternative.", ctx) } - if (ctx.skewSpec != null) { + if (ctx.skewSpec.size > 0) { operationNotAllowed("CREATE TABLE ... 
SKEWED BY", ctx) } + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil) val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil) - val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) val selectQuery = Option(ctx.query).map(plan) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) // Note: Hive requires partition columns to be distinct from the schema, so we need // to include the partition columns here explicitly @@ -1120,12 +1142,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { // Storage format val defaultStorage = HiveSerDe.getDefaultStorage(conf) - validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) - val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + validateRowFormatFileFormat(ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx) + val fileStorage = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat) .getOrElse(CatalogStorageFormat.empty) - val rowStorage = Option(ctx.rowFormat).map(visitRowFormat) + val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat) .getOrElse(CatalogStorageFormat.empty) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) // If we are creating an EXTERNAL table, then the LOCATION field is required if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) @@ -1180,7 +1202,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx) } - val hasStorageProperties = (ctx.createFileFormat != null) || (ctx.rowFormat != null) + val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0) if (conf.convertCTAS && !hasStorageProperties) { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. @@ -1366,6 +1388,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } } + private def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + /** * Create or replace a view. This creates a [[CreateViewCommand]] command. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 910294853c31..82b4eb9fba24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.{execution, AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -90,23 +91,58 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Select the proper physical plan for join based on joining keys and size of logical plan. * * At first, uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the - * predicates can be evaluated by matching join keys. If found, Join implementations are chosen + * predicates can be evaluated by matching join keys. If found, join implementations are chosen * with the following precedence: * - * - Broadcast: We prefer to broadcast the join side with an explicit broadcast hint(e.g. the - * user applied the [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame). - * If both sides have the broadcast hint, we prefer to broadcast the side with a smaller - * estimated physical size. If neither one of the sides has the broadcast hint, - * we only broadcast the join side if its estimated physical size that is smaller than - * the user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold. + * - Broadcast hash join (BHJ): + * BHJ is not supported for full outer join. For right outer join, we only can broadcast the + * left side. For left outer, left semi, left anti and the internal join type ExistenceJoin, + * we only can broadcast the right side. For inner like join, we can broadcast both sides. + * Normally, BHJ can perform faster than the other join algorithms when the broadcast side is + * small. However, broadcasting tables is a network-intensive operation. It could cause OOM + * or perform worse than the other join algorithms, especially when the build/broadcast side + * is big. + * + * For the supported cases, users can specify the broadcast hint (e.g. the user applied the + * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame) and session-based + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to adjust whether BHJ is used and + * which join side is broadcast. + * + * 1) Broadcast the join side with the broadcast hint, even if the size is larger than + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (only when the type + * is inner like join), the side with a smaller estimated physical size will be broadcast. + * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side + * whose estimated physical size is smaller than the threshold. If both sides are below the + * threshold, broadcast the smaller side. If neither is smaller, BHJ is not used. + * * - Shuffle hash join: if the average size of a single partition is small enough to build a hash * table. + * * - Sort merge: if the matching join keys are sortable. 
* * If there is no joining keys, Join implementations are chosen with the following precedence: - * - BroadcastNestedLoopJoin: if one side of the join could be broadcasted - * - CartesianProduct: for Inner join - * - BroadcastNestedLoopJoin + * - BroadcastNestedLoopJoin (BNLJ): + * BNLJ supports all the join types but the impl is OPTIMIZED for the following scenarios: + * For right outer join, the left side is broadcast. For left outer, left semi, left anti + * and the internal join type ExistenceJoin, the right side is broadcast. For inner like + * joins, either side is broadcast. + * + * Like BHJ, users still can specify the broadcast hint and session-based + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to impact which side is broadcast. + * + * 1) Broadcast the join side with the broadcast hint, even if the size is larger than + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (i.e., just for + * inner-like join), the side with a smaller estimated physical size will be broadcast. + * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side + * whose estimated physical size is smaller than the threshold. If both sides are below the + * threshold, broadcast the smaller side. If neither is smaller, BNLJ is not used. + * + * - CartesianProduct: for inner like join, CartesianProduct is the fallback option. + * + * - BroadcastNestedLoopJoin (BNLJ): + * For the other join types, BNLJ is the fallback option. Here, we just pick the broadcast + * side with the broadcast hint. If neither side has a hint, we broadcast the side with + * the smaller estimated physical size. */ object JoinSelection extends Strategy with PredicateHelper { @@ -139,8 +175,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } private def canBuildRight(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => true - case j: ExistenceJoin => true + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true case _ => false } @@ -243,7 +278,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Without joining keys ------------------------------------------------------------ - // Pick BroadcastNestedLoopJoin if one side could be broadcasted + // Pick BroadcastNestedLoopJoin if one side could be broadcast case j @ logical.Join(left, right, joinType, condition) if canBroadcastByHints(joinType, left, right) => val buildSide = broadcastSideByHints(joinType, left, right) @@ -288,9 +323,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => + if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + throw new AnalysisException( + "Streaming aggregation doesn't support group aggregate pandas UDF") + } + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, - aggregateExpressions, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, planLater(child)) @@ -333,8 +373,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalAggregation( - groupingExpressions, aggregateExpressions, resultExpressions, child) => + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => 
expr.isInstanceOf[AggregateExpression]) => + val aggregateExpressions = aggExpressions.map(expr => + expr.asInstanceOf[AggregateExpression]) val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) @@ -363,6 +405,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateOperator + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF]) => + val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF]) + + Seq(execution.python.AggregateInPandasExec( + groupingExpressions, + udfExpressions, + resultExpressions, + planLater(child))) + + case PhysicalAggregation(_, _, _, _) => + // If cannot match the two cases above, then it's an error + throw new AnalysisException( + "Cannot use a mixture of aggregate function and group aggregate pandas UDF") + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 065954559e48..828b51fa199d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.execution import java.util.Locale +import java.util.function.Supplier + +import scala.collection.mutable import org.apache.spark.broadcast import org.apache.spark.rdd.RDD @@ -58,7 +61,7 @@ trait CodegenSupport extends SparkPlan { } /** - * Whether this SparkPlan support whole stage codegen or not. + * Whether this SparkPlan supports whole stage codegen or not. */ def supportCodegen: Boolean = true @@ -106,6 +109,31 @@ trait CodegenSupport extends SparkPlan { */ protected def doProduce(ctx: CodegenContext): String + private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { + if (row != null) { + ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow])) + } else { + if (colVars.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + val evaluateInputs = evaluateVariables(colVars) + // generate the code to create a UnsafeRow + ctx.INPUT_ROW = row + ctx.currentVars = colVars + val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + val code = s""" + |$evaluateInputs + |${ev.code.trim} + """.stripMargin.trim + ExprCode(code, FalseLiteral, ev.value) + } else { + // There are no columns + ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow])) + } + } + } + /** * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`. 
* @@ -126,28 +154,7 @@ trait CodegenSupport extends SparkPlan { } } - val rowVar = if (row != null) { - ExprCode("", "false", row) - } else { - if (outputVars.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - val evaluateInputs = evaluateVariables(outputVars) - // generate the code to create a UnsafeRow - ctx.INPUT_ROW = row - ctx.currentVars = outputVars - val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - val code = s""" - |$evaluateInputs - |${ev.code.trim} - """.stripMargin.trim - ExprCode(code, "false", ev.value) - } else { - // There is no columns - ExprCode("", "false", "unsafeRow") - } - } + val rowVar = prepareRowVar(ctx, row, outputVars) // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars` // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to @@ -156,13 +163,97 @@ trait CodegenSupport extends SparkPlan { ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) + + // Under certain conditions, we can put the logic to consume the rows of this operator into + // another function, so that a generated function does not grow too long to be optimized by JIT. + // The conditions: + // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled. + // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses + // all variables in output (see `requireAllOutput`). + // 3. The number of output variables must be less than the maximum number of parameters in a + // Java method declaration. + val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator + val requireAllOutput = output.forall(parent.usedInputs.contains(_)) + val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0) + val consumeFunc = if (confEnabled && requireAllOutput + && CodeGenerator.isValidParamLength(paramLength)) { + constructDoConsumeFunction(ctx, inputVars, row) + } else { + parent.doConsume(ctx, inputVars, rowVar) + } s""" |${ctx.registerComment(s"CONSUME: ${parent.simpleString}")} |$evaluated - |${parent.doConsume(ctx, inputVars, rowVar)} + |$consumeFunc """.stripMargin + } + + /** + * To prevent the concatenated function from growing too long to be optimized by JIT, we separate + * the parent's `doConsume` code of a `CodegenSupport` operator into a function to call. + */ + private def constructDoConsumeFunction( + ctx: CodegenContext, + inputVars: Seq[ExprCode], + row: String): String = { + val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row) + val rowVar = prepareRowVar(ctx, row, inputVarsInFunc) + + val doConsume = ctx.freshName("doConsume") + ctx.currentVars = inputVarsInFunc + ctx.INPUT_ROW = null + + val doConsumeFuncName = ctx.addNewFunction(doConsume, + s""" + | private void $doConsume(${params.mkString(", ")}) throws java.io.IOException { + | ${parent.doConsume(ctx, inputVarsInFunc, rowVar)} + | } + """.stripMargin) + + s""" + | $doConsumeFuncName(${args.mkString(", ")}); """.stripMargin } + /** + * Returns the arguments for calling the consume function, its method definition parameters, + * and the list of `ExprCode` for the parameters.
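A hedged way to exercise the split-consume path described in the conditions above (the config key is the one quoted in the comment; the query itself is just an example): with the flag enabled and all of an operator's output consumed by its parent, the per-operator consume logic is emitted as a separate private method in the generated class, which debugCodegen() makes visible.

    import org.apache.spark.sql.execution.debug._

    spark.conf.set("spark.sql.codegen.splitConsumeFuncByOperator", "true")
    // groupBy/count consumes every output column of the projected scan, so the generated code
    // routes consumption through a doConsume-style helper method instead of inlining it.
    spark.range(1000).selectExpr("id % 7 AS k").groupBy("k").count().debugCodegen()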
+ */ + private def constructConsumeParameters( + ctx: CodegenContext, + attributes: Seq[Attribute], + variables: Seq[ExprCode], + row: String): (Seq[String], Seq[String], Seq[ExprCode]) = { + val arguments = mutable.ArrayBuffer[String]() + val parameters = mutable.ArrayBuffer[String]() + val paramVars = mutable.ArrayBuffer[ExprCode]() + + if (row != null) { + arguments += row + parameters += s"InternalRow $row" + } + + variables.zipWithIndex.foreach { case (ev, i) => + val paramName = ctx.freshName(s"expr_$i") + val paramType = CodeGenerator.javaType(attributes(i).dataType) + + arguments += ev.value + parameters += s"$paramType $paramName" + val paramIsNull = if (!attributes(i).nullable) { + // Use constant `false` without passing `isNull` for non-nullable variable. + FalseLiteral + } else { + val isNull = ctx.freshName(s"exprIsNull_$i") + arguments += ev.isNull + parameters += s"boolean $isNull" + JavaCode.isNullVariable(isNull) + } + + paramVars += ExprCode(paramIsNull, JavaCode.variable(paramName, attributes(i).dataType)) + } + (arguments, parameters, paramVars) + } + /** * Returns source code to evaluate all the variables, and clear the code of them, to prevent * them to be evaluated twice. @@ -325,6 +416,58 @@ object WholeStageCodegenExec { } } +object WholeStageCodegenId { + // codegenStageId: ID for codegen stages within a query plan. + // It does not affect equality, nor does it participate in destructuring pattern matching + // of WholeStageCodegenExec. + // + // This ID is used to help differentiate between codegen stages. It is included as a part + // of the explain output for physical plans, e.g. + // + // == Physical Plan == + // *(5) SortMergeJoin [x#3L], [y#9L], Inner + // :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0 + // : +- Exchange hashpartitioning(x#3L, 200) + // : +- *(1) Project [(id#0L % 2) AS x#3L] + // : +- *(1) Filter isnotnull((id#0L % 2)) + // : +- *(1) Range (0, 5, step=1, splits=8) + // +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0 + // +- Exchange hashpartitioning(y#9L, 200) + // +- *(3) Project [(id#6L % 2) AS y#9L] + // +- *(3) Filter isnotnull((id#6L % 2)) + // +- *(3) Range (0, 5, step=1, splits=8) + // + // where the ID makes it obvious that not all adjacent codegen'd plan operators are of the + // same codegen stage. + // + // The codegen stage ID is also optionally included in the name of the generated classes as + // a suffix, so that it's easier to associate a generated class back to the physical operator. + // This is controlled by SQLConf: spark.sql.codegen.useIdInClassName + // + // The ID is also included in various log messages. + // + // Within a query, a codegen stage in a plan starts counting from 1, in "insertion order". + // WholeStageCodegenExec operators are inserted into a plan in depth-first post-order. + // See CollapseCodegenStages.insertWholeStageCodegen for the definition of insertion order. + // + // 0 is reserved as a special ID value to indicate a temporary WholeStageCodegenExec object + // is created, e.g. for special fallback handling when an existing WholeStageCodegenExec + // failed to generate/compile code. 
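A short sketch of how the codegen stage ID described above surfaces to users (the query is arbitrary; the config key is the one named in the comment): the `*(N)` prefix shows up in explain output, and with spark.sql.codegen.useIdInClassName enabled the same N is appended to the generated class name.

    spark.conf.set("spark.sql.codegen.useIdInClassName", "true")
    // explain() prefixes whole-stage-codegen'd operators with "*(1)", "*(2)", ..., one number
    // per collapsed codegen stage, matching the codegenStageId assigned at planning time.
    spark.range(0, 5).filter("id % 2 = 0").explain()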
+ + private val codegenStageCounter = ThreadLocal.withInitial(new Supplier[Integer] { + override def get() = 1 // TODO: change to Scala lambda syntax when upgraded to Scala 2.12+ + }) + + def resetPerQuery(): Unit = codegenStageCounter.set(1) + + def getNextStageId(): Int = { + val counter = codegenStageCounter + val id = counter.get() + counter.set(id + 1) + id + } +} + /** * WholeStageCodegen compiles a subtree of plans that support codegen together into single Java * function. @@ -353,7 +496,8 @@ object WholeStageCodegenExec { * `doCodeGen()` will create a `CodeGenContext`, which will hold a list of variables for input, * used to generated code for [[BoundReference]]. */ -case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { +case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) + extends UnaryExecNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -365,6 +509,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext, WholeStageCodegenExec.PIPELINE_DURATION_METRIC)) + def generatedClassName(): String = if (conf.wholeStageUseIdInClassName) { + s"GeneratedIteratorForCodegenStage$codegenStageId" + } else { + "GeneratedIterator" + } + /** * Generates code for this subtree. * @@ -382,19 +532,25 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } """, inlineToOuterClass = true) + val className = generatedClassName() + val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new $className(references); } - ${ctx.registerComment(s"""Codegend pipeline for\n${child.treeString.trim}""")} - final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { + ${ctx.registerComment( + s"""Codegend pipeline for stage (id=$codegenStageId) + |${this.treeString.trim}""".stripMargin, + "wsc_codegenPipeline")} + ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)} + final class $className extends ${classOf[BufferedRowIterator].getName} { private Object[] references; private scala.collection.Iterator[] inputs; ${ctx.declareMutableStates()} - public GeneratedIterator(Object[] references) { + public $className(Object[] references) { this.references = references; } @@ -427,7 +583,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } catch { case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback => // We should already saw the error message - logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString") + logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString") return child.execute() } @@ -436,7 +592,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co logInfo(s"Found too long generated codes and JIT optimization might not work: " + s"the bytecode size ($maxCodeSize) is above the limit " + s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + - s"for this plan. To avoid this, you can raise the limit " + + s"for this plan (id=$codegenStageId). 
To avoid this, you can raise the limit " + s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString") child match { // The fallback solution of batch file source scan still uses WholeStageCodegenExec @@ -514,10 +670,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co verbose: Boolean, prefix: String = "", addSuffix: Boolean = false): StringBuilder = { - child.generateTreeString(depth, lastChildren, builder, verbose, "*") + child.generateTreeString(depth, lastChildren, builder, verbose, s"*($codegenStageId) ") } override def needStopCheck: Boolean = true + + override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer]) } @@ -568,13 +726,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] => plan.withNewChildren(plan.children.map(insertWholeStageCodegen)) case plan: CodegenSupport if supportCodegen(plan) => - WholeStageCodegenExec(insertInputAdapter(plan)) + WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId()) case other => other.withNewChildren(other.children.map(insertWholeStageCodegen)) } def apply(plan: SparkPlan): SparkPlan = { if (conf.wholeStageEnabled) { + WholeStageCodegenId.resetPerQuery() insertWholeStageCodegen(plan) } else { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9a6f1c6dfa6a..6a8ec4f722ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.{ColumnarRow, MutableColumnarRow} +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils @@ -178,7 +178,7 @@ case class HashAggregateExec( private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") // The generated function doesn't have input row in the code context. 
ctx.INPUT_ROW = null @@ -186,15 +186,18 @@ case class HashAggregateExec( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) bufVars = initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode( + ev.code + initVars, + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, e.dataType)) } val initBufVar = evaluateVariables(bufVars) @@ -532,7 +535,7 @@ case class HashAggregateExec( */ private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = { val isSupported = - (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) || + (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) || f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) && bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge) @@ -565,7 +568,7 @@ case class HashAggregateExec( } private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -752,12 +755,15 @@ case class HashAggregateExec( } // generate hash code for key - val hashExpr = Murmur3Hash(groupingExpressions, 42) + // SPARK-24076: HashAggregate uses the same hash algorithm on the same expressions + // as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n, + // pick a different seed to avoid this conflict + val hashExpr = Murmur3Hash(groupingExpressions, 48) val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") (s"$countTerm < ${testFallbackStartsAt.get._1}", s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { @@ -832,7 +838,7 @@ case class HashAggregateExec( } val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) } s""" |// common sub-expressions @@ -855,7 +861,7 @@ case class HashAggregateExec( } val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn( + CodeGenerator.updateColumn( fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 1c613b19c4ab..de2d630de3fd 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ /** @@ -41,20 +41,23 @@ abstract class HashMapGenerator( val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key"))) val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value"))) val groupingKeySignature = - groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ") + groupingKeys.map(key => s"${CodeGenerator.javaType(key.dataType)} ${key.name}").mkString(", ") val buffVars: Seq[ExprCode] = { val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = s""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode( + ev.code + initVars, + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, e.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index fd25707dd4ca..d5508275c48c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.types._ /** @@ -114,7 +114,7 @@ class RowBasedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue("row", + s"""(${ctx.genEqual(key.dataType, CodeGenerator.getValue("row", key.dataType, ordinal.toString()), key.name)})""" }.mkString(" && ") } @@ -147,7 +147,7 @@ class RowBasedHashMapGenerator( case t: DecimalType => s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})" case t: DataType => - if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) { + if (!t.isInstanceOf[StringType] && !CodeGenerator.isPrimitiveType(t)) { throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t") } s"agg_rowWriter.write(${ordinal}, ${key.name})" @@ -165,18 +165,14 @@ class 
RowBasedHashMapGenerator( | if (buckets[idx] == -1) { | if (numRows < capacity && !isBatchFull) { | // creating the unsafe for new entry - | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length}); - | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder - | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, - | ${numVarLenFields * 32}); | org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( - | agg_holder, - | ${groupingKeySchema.length}); - | agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed + | ${groupingKeySchema.length}, ${numVarLenFields * 32}); + | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed | agg_rowWriter.zeroOutNullBytes(); | ${createUnsafeRowForKey}; - | agg_result.setTotalSize(agg_holder.totalSize()); + | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result + | = agg_rowWriter.getRow(); | Object kbase = agg_result.getBaseObject(); | long koff = agg_result.getBaseOffset(); | int klen = agg_result.getSizeInBytes(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0380ee8b09d6..7b3580cecc60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} +import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch /** * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache' @@ -93,7 +94,7 @@ class VectorizedHashMapGenerator( | | public $generatedClassName() { | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema); - | batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity); + | batch = new ${classOf[ColumnarBatch].getName}(vectors); | | // Generates a projection to return the aggregate buffer only. 
| ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors = @@ -126,8 +127,9 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]", - key.dataType), key.name)})""" + val value = CodeGenerator.getValueFromVector(s"vectors[$ordinal]", key.dataType, + "buckets[idx]") + s"(${ctx.genEqual(key.dataType, value, key.name)})" }.mkString(" && ") } @@ -181,14 +183,14 @@ class VectorizedHashMapGenerator( def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) + CodeGenerator.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) } } def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = { bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, - buffVars(ordinal), nullable = true) + CodeGenerator.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", + key.dataType, buffVars(ordinal), nullable = true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcfc41243026..7487564ed64d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -32,8 +32,8 @@ import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils @@ -175,7 +175,7 @@ private[sql] object ArrowConverters { new ArrowColumnVector(vector).asInstanceOf[ColumnVector] }.toArray - val batch = new ColumnarBatch(schemaRead, columns, root.getRowCount) + val batch = new ColumnarBatch(columns) batch.setNumRows(root.getRowCount) batch.rowIterator().asScala } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a15a8d11aa2a..1edfdc888afd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -192,7 +192,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) // generate better code 
(remove dead branches). val resultVars = input.zipWithIndex.map { case (ev, i) => if (notNullAttributes.contains(child.output(i).exprId)) { - ev.isNull = "false" + ev.isNull = FalseLiteral } ev } @@ -364,11 +364,11 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) protected override def doProduce(ctx: CodegenContext): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange") - val number = ctx.addMutableState(ctx.JAVA_LONG, "number") + val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange") + val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", "false", value) + val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType)) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task @@ -385,10 +385,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // the metrics. // Once number == batchEnd, it's time to progress to the next batch. - val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd") + val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd") // How many values should still be generated by this range operator. - val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo") + val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo") // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 4f28eeb725cb..2d699e8a9d08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -91,7 +91,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val accessorName = ctx.addMutableState(accessorCls, "accessor") val createCode = dt match { - case t if ctx.isPrimitiveType(dt) => + case t if CodeGenerator.isPrimitiveType(dt) => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case NullType | StringType | BinaryType => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" @@ -165,9 +165,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; - private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private BufferHolder bufferHolder = new BufferHolder(unsafeRow); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter($numFields); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -212,11 +210,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public InternalRow next() { currentRow += 1; - bufferHolder.reset(); + rowWriter.reset(); rowWriter.zeroOutNullBytes(); ${extractorCalls} - unsafeRow.setTotalSize(bufferHolder.totalSize()); - return unsafeRow; + return rowWriter.getRow(); } ${ctx.declareAddedFunctions()} diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 51928d914841..da35a4734e65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -24,26 +24,14 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator -object InMemoryRelation { - def apply( - useCompression: Boolean, - batchSize: Int, - storageLevel: StorageLevel, - child: SparkPlan, - tableName: Option[String], - statsOfPlanToCache: Statistics): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = statsOfPlanToCache) -} - - /** * CachedBatch is a cached batch of rows. * @@ -54,42 +42,41 @@ object InMemoryRelation { private[columnar] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) -case class InMemoryRelation( - output: Seq[Attribute], +case class CachedRDDBuilder( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - @transient child: SparkPlan, + @transient cachedPlan: SparkPlan, tableName: Option[String])( - @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics = null) - extends logical.LeafNode with MultiInstanceRelation { - - override protected def innerChildren: Seq[SparkPlan] = Seq(child) + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null) { - override def producedAttributes: AttributeSet = outputSet - - @transient val partitionStatistics = new PartitionStatistics(output) + val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator - override def computeStats(): Statistics = { - if (batchStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache - statsOfPlanToCache - } else { - Statistics(sizeInBytes = batchStats.value.longValue) + def cachedColumnBuffers: RDD[CachedBatch] = { + if (_cachedColumnBuffers == null) { + synchronized { + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildBuffers() + } + } } + _cachedColumnBuffers } - // If the cached column buffers were not passed in, we calculate them in the constructor. - // As in Spark, the actual work of caching is lazy. 
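// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): CachedRDDBuilder above now
// materializes its RDD[CachedBatch] lazily with double-checked locking instead
// of building it in the constructor. The same idiom with a plain value in
// place of the RDD; LazyBuffers / build are hypothetical names, and the
// @volatile annotation is added here for a safe general-purpose version of
// the pattern.
class LazyBuffers(build: () => Vector[Array[Byte]]) {
  @volatile private var cached: Vector[Array[Byte]] = null

  def get(): Vector[Array[Byte]] = {
    if (cached == null) {
      synchronized {
        if (cached == null) {
          cached = build()   // first caller pays the cost; later callers reuse it
        }
      }
    }
    cached
  }

  def clear(): Unit = synchronized { cached = null }
}

// Usage: val buffers = new LazyBuffers(() => Vector(Array[Byte](1, 2, 3))); buffers.get()
// ---------------------------------------------------------------------------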
- if (_cachedColumnBuffers == null) { - buildBuffers() + def clearCache(blocking: Boolean = true): Unit = { + if (_cachedColumnBuffers != null) { + synchronized { + if (_cachedColumnBuffers != null) { + _cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } + } + } } - private def buildBuffers(): Unit = { - val output = child.output - val cached = child.execute().mapPartitionsInternal { rowIterator => + private def buildBuffers(): RDD[CachedBatch] = { + val output = cachedPlan.output + val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator => new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => @@ -122,7 +109,7 @@ case class InMemoryRelation( rowCount += 1 } - batchStats.add(totalSize) + sizeInBytesStats.add(totalSize) val stats = InternalRow.fromSeq( columnBuilders.flatMap(_.columnStats.collectedStatistics)) @@ -137,31 +124,77 @@ case class InMemoryRelation( cached.setName( tableName.map(n => s"In-memory table $n") - .getOrElse(StringUtils.abbreviate(child.toString, 1024))) - _cachedColumnBuffers = cached + .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024))) + cached + } +} + +object InMemoryRelation { + + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan, + tableName: Option[String], + logicalPlan: LogicalPlan): InMemoryRelation = { + val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)() + new InMemoryRelation(child.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } + + def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = { + new InMemoryRelation(cacheBuilder.cachedPlan.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } +} + +case class InMemoryRelation( + output: Seq[Attribute], + @transient cacheBuilder: CachedRDDBuilder)( + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) + extends logical.LeafNode with MultiInstanceRelation { + + override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) + + override def doCanonicalize(): logical.LogicalPlan = + copy(output = output.map(QueryPlan.normalizeExprId(_, cachedPlan.output)), + cacheBuilder)( + statsOfPlanToCache, + outputOrdering) + + override def producedAttributes: AttributeSet = outputSet + + @transient val partitionStatistics = new PartitionStatistics(output) + + def cachedPlan: SparkPlan = cacheBuilder.cachedPlan + + override def computeStats(): Statistics = { + if (cacheBuilder.sizeInBytesStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + // Note that we should drop the hint info here. We may cache a plan whose root node is a hint + // node. When we lookup the cache with a semantically same plan without hint info, the plan + // returned by cache lookup should not have hint info. If we lookup the cache with a + // semantically same plan with a different hint info, `CacheManager.useCachedData` will take + // care of it and retain the hint info in the lookup input plan. 
+ statsOfPlanToCache.copy(hints = HintInfo()) + } else { + Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue) + } } def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { - InMemoryRelation( - newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, batchStats, statsOfPlanToCache) + InMemoryRelation(newOutput, cacheBuilder)(statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), - useCompression, - batchSize, - storageLevel, - child, - tableName)( - _cachedColumnBuffers, - batchStats, - statsOfPlanToCache).asInstanceOf[this.type] + cacheBuilder)( + statsOfPlanToCache, + outputOrdering).asInstanceOf[this.type] } - def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers - - override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache) + override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3e73393b1285..0b4dd76c7d86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -24,9 +24,10 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} -import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} case class InMemoryTableScanExec( @@ -37,6 +38,11 @@ case class InMemoryTableScanExec( override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + override def doCanonicalize(): SparkPlan = + copy(attributes = attributes.map(QueryPlan.normalizeExprId(_, relation.output)), + predicates = predicates.map(QueryPlan.normalizeExprId(_, relation.output)), + relation = relation.canonicalized.asInstanceOf[InMemoryRelation]) + override def vectorTypes: Option[Seq[String]] = Option(Seq.fill(attributes.length)( if (!conf.offHeapColumnVectorEnabled) { @@ -48,18 +54,23 @@ case class InMemoryTableScanExec( /** * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. - * If false, get data from UnsafeRow build from ColumnVector + * If false, get data from UnsafeRow build from CachedBatch */ - override val supportCodegen: Boolean = { + override val supportsBatch: Boolean = { // In the initial implementation, for ease of review // support only primitive data types and # of fields is less than wholeStageMaxNumFields - relation.schema.fields.forall(f => f.dataType match { + conf.cacheVectorizedReaderEnabled && relation.schema.fields.forall(f => f.dataType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true case _ => false }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) } + // TODO: revisit this. 
Shall we always turn off whole stage codegen if the output data are rows? + override def supportCodegen: Boolean = supportsBatch + + override protected def needsUnsafeRowConversion: Boolean = false + private val columnIndices = attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray @@ -67,19 +78,20 @@ case class InMemoryTableScanExec( private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i))) - private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { + private def createAndDecompressColumn( + cachedColumnarBatch: CachedBatch, + offHeapColumnVectorEnabled: Boolean): ColumnarBatch = { val rowCount = cachedColumnarBatch.numRows val taskContext = Option(TaskContext.get()) - val columnVectors = if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) { + val columnVectors = if (!offHeapColumnVectorEnabled || taskContext.isEmpty) { OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } else { OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } - val columnarBatch = new ColumnarBatch( - columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount) + val columnarBatch = new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]]) columnarBatch.setNumRows(rowCount) - for (i <- 0 until attributes.length) { + for (i <- attributes.indices) { ColumnAccessor.decompress( cachedColumnarBatch.buffers(columnIndices(i)), columnarBatch.column(i).asInstanceOf[WritableColumnVector], @@ -89,20 +101,65 @@ case class InMemoryTableScanExec( columnarBatch } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - assert(supportCodegen) + private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() - // HACK ALERT: This is actually an RDD[ColumnarBatch]. - // We're taking advantage of Scala's type erasure here to pass these batches along. - Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) + val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled + if (supportsBatch) { + // HACK ALERT: This is actually an RDD[ColumnarBatch]. + // We're taking advantage of Scala's type erasure here to pass these batches along. + buffers + .map(createAndDecompressColumn(_, offHeapColumnVectorEnabled)) + .asInstanceOf[RDD[InternalRow]] + } else { + val numOutputRows = longMetric("numOutputRows") + + if (enableAccumulatorsForTest) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + // Using these variables here to avoid serialization of entire objects (if referenced + // directly) within the map Partitions closure. + val relOutput: AttributeSeq = relation.output + + filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => + // Find the ordinals and data types of the requested columns. 
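// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the row-based branch that
// follows maps each requested attribute back to its ordinal in the cached
// relation's full output and pairs it with its data type, since the scan may
// have pruned columns. A plain-Scala model of that bookkeeping; Attr and
// requestedColumns are hypothetical names.
object RequestedColumns {
  final case class Attr(name: String, dataType: String)

  def requestedColumns(
      relationOutput: Seq[Attr],
      requested: Seq[Attr]): (Array[Int], Array[String]) = {
    val (indices, types) = requested.map { a =>
      relationOutput.indexWhere(_.name == a.name) -> a.dataType
    }.unzip
    (indices.toArray, types.toArray)
  }

  def main(args: Array[String]): Unit = {
    val output = Seq(Attr("id", "long"), Attr("item", "string"))
    val pruned = Seq(Attr("item", "string"))
    val (idx, types) = requestedColumns(output, pruned)
    println(s"${idx.mkString(",")} -> ${types.mkString(",")}")   // 1 -> string
  }
}
// ---------------------------------------------------------------------------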
+ val (requestedColumnIndices, requestedColumnDataTypes) = + attributes.map { a => + relOutput.indexOf(a.exprId) -> a.dataType + }.unzip + + // update SQL metrics + val withMetrics = cachedBatchIterator.map { batch => + if (enableAccumulatorsForTest) { + readBatches.add(1) + } + numOutputRows += batch.numRows + batch + } + + val columnTypes = requestedColumnDataTypes.map { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + }.toArray + val columnarIterator = GenerateColumnAccessor.generate(columnTypes) + columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) + if (enableAccumulatorsForTest && columnarIterator.hasNext) { + readPartitions.add(1) + } + columnarIterator + } + } } + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) + override def output: Seq[Attribute] = attributes private def updateAttribute(expr: Expression): Expression = { // attributes can be pruned so using relation's output. // E.g., relation.output is [id, item] but this scan's output can be [item] only. - val attrMap = AttributeMap(relation.child.output.zip(relation.output)) + val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output)) expr.transform { case attr: Attribute => attrMap.getOrElse(attr, attr) } @@ -111,22 +168,24 @@ case class InMemoryTableScanExec( // The cached version does not change the outputPartitioning of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { - relation.child.outputPartitioning match { + relation.cachedPlan.outputPartitioning match { case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.child.outputPartitioning + case _ => relation.cachedPlan.outputPartitioning } } // The cached version does not change the outputOrdering of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputOrdering: Seq[SortOrder] = - relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) + relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) - private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) + // Keeps relation's partition statistics because we don't serialize relation. + private val stats = relation.partitionStatistics + private def statsFor(a: Attribute) = stats.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. 
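// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): buildFilter below rewrites
// predicates into checks over the per-batch column statistics, and a batch is
// skipped only when the statistics prove that no row in it can satisfy the
// predicate. A plain-Scala model of that pruning for one integer column;
// BatchStats, Pred and canMatch are hypothetical names.
object BatchPruning {
  final case class BatchStats(lowerBound: Int, upperBound: Int)

  sealed trait Pred
  final case class EqualTo(v: Int) extends Pred
  final case class LessThan(v: Int) extends Pred
  final case class GreaterThan(v: Int) extends Pred

  def canMatch(stats: BatchStats, p: Pred): Boolean = p match {
    case EqualTo(v)     => stats.lowerBound <= v && v <= stats.upperBound
    case LessThan(v)    => stats.lowerBound < v
    case GreaterThan(v) => v < stats.upperBound
  }

  def main(args: Array[String]): Unit = {
    val stats = BatchStats(lowerBound = 10, upperBound = 20)
    println(canMatch(stats, EqualTo(5)))      // false: the batch can be skipped
    println(canMatch(stats, GreaterThan(12))) // true: the batch must be scanned
  }
}
// ---------------------------------------------------------------------------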
- @transient val buildFilter: PartialFunction[Expression, Expression] = { + @transient lazy val buildFilter: PartialFunction[Expression, Expression] = { case And(lhs: Expression, rhs: Expression) if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) @@ -166,14 +225,14 @@ case class InMemoryTableScanExec( l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) } - val partitionFilters: Seq[Expression] = { + lazy val partitionFilters: Seq[Expression] = { predicates.flatMap { p => val filter = buildFilter.lift(p) val boundFilter = filter.map( BindReferences.bindReference( _, - relation.partitionStatistics.schema, + stats.schema, allowFailures = true)) boundFilter.foreach(_ => @@ -184,7 +243,7 @@ case class InMemoryTableScanExec( } } - lazy val enableAccumulators: Boolean = + lazy val enableAccumulatorsForTest: Boolean = sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean // Accumulators used for testing purposes @@ -196,9 +255,9 @@ case class InMemoryTableScanExec( private def filteredCachedBatches(): RDD[CachedBatch] = { // Using these variables here to avoid serialization of entire objects (if referenced directly) // within the map Partitions closure. - val schema = relation.partitionStatistics.schema + val schema = stats.schema val schemaIndex = schema.zipWithIndex - val buffers = relation.cachedColumnBuffers + val buffers = relation.cacheBuilder.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( @@ -229,43 +288,10 @@ case class InMemoryTableScanExec( } protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - if (enableAccumulators) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - - // Using these variables here to avoid serialization of entire objects (if referenced directly) - // within the map Partitions closure. - val relOutput: AttributeSeq = relation.output - - filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => - // Find the ordinals and data types of the requested columns. 
- val (requestedColumnIndices, requestedColumnDataTypes) = - attributes.map { a => - relOutput.indexOf(a.exprId) -> a.dataType - }.unzip - - // update SQL metrics - val withMetrics = cachedBatchIterator.map { batch => - if (enableAccumulators) { - readBatches.add(1) - } - numOutputRows += batch.numRows - batch - } - - val columnTypes = requestedColumnDataTypes.map { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - }.toArray - val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) - if (enableAccumulators && columnarIterator.hasNext) { - readPartitions.add(1) - } - columnarIterator + if (supportsBatch) { + WholeStageCodegenExec(this)(codegenStageId = 0).execute() + } else { + inputRDD } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 79dcf3a6105c..00a1d54b4170 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -116,7 +116,7 @@ private[columnar] case object PassThrough extends CompressionScheme { while (pos < capacity) { if (pos != nextNullIndex) { val len = nextNullIndex - pos - assert(len * unitSize < Int.MaxValue) + assert(len * unitSize.toLong < Int.MaxValue) putFunction(columnVector, pos, bufferPos, len) bufferPos += len * unitSize pos += len diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 1122522ccb4c..640e01336aa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.command import scala.collection.mutable import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ /** @@ -64,12 +66,12 @@ case class AnalyzeColumnCommand( /** * Compute stats for the given columns. - * @return (row count, map from column name to ColumnStats) + * @return (row count, map from column name to CatalogColumnStats) */ private def computeColumnStats( sparkSession: SparkSession, tableIdent: TableIdentifier, - columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { + columnNames: Seq[String]): (Long, Map[String, CatalogColumnStat]) = { val conf = sparkSession.sessionState.conf val relation = sparkSession.table(tableIdent).logicalPlan @@ -81,7 +83,7 @@ case class AnalyzeColumnCommand( // Make sure the column types are supported for stats gathering. 
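// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): why the PassThrough assert
// above widens to Long. With two Ints, `len * unitSize` is evaluated in 32-bit
// arithmetic and can silently wrap around before the comparison, so the assert
// may pass on an overflowed value; widening one operand performs the multiply
// in 64-bit. The numbers below are made up for the demonstration.
object OverflowDemo {
  def main(args: Array[String]): Unit = {
    val len = 300 * 1000 * 1000       // 3e8 decompressed elements
    val unitSize = 8                  // e.g. 8-byte LongType values

    println(len * unitSize)                        // -1894967296 (wrapped around)
    println(len * unitSize.toLong)                 // 2400000000

    println(len * unitSize < Int.MaxValue)         // true, but only because of the overflow
    println(len * unitSize.toLong < Int.MaxValue)  // false, as intended
  }
}
// ---------------------------------------------------------------------------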
attributesToAnalyze.foreach { attr => - if (!ColumnStat.supportsType(attr.dataType)) { + if (!supportsType(attr.dataType)) { throw new AnalysisException( s"Column ${attr.name} in table $tableIdent is of type ${attr.dataType}, " + "and Spark does not support statistics collection on this column type.") @@ -103,7 +105,7 @@ case class AnalyzeColumnCommand( // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStat.statExprs(_, conf, attributePercentiles)) + attributesToAnalyze.map(statExprs(_, conf, attributePercentiles)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) @@ -111,9 +113,9 @@ case class AnalyzeColumnCommand( val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => - // according to `ColumnStat.statExprs`, the stats struct always have 7 fields. - (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, - attributePercentiles.get(attr))) + // according to `statExprs`, the stats struct always have 7 fields. + (attr.name, rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, + attributePercentiles.get(attr)).toCatalogColumnStat(attr.name, attr.dataType)) }.toMap (rowCount, columnStats) } @@ -124,7 +126,7 @@ case class AnalyzeColumnCommand( sparkSession: SparkSession, relation: LogicalPlan): AttributeMap[ArrayData] = { val attrsToGenHistogram = if (conf.histogramEnabled) { - attributesToAnalyze.filter(a => ColumnStat.supportsHistogram(a.dataType)) + attributesToAnalyze.filter(a => supportsHistogram(a.dataType)) } else { Nil } @@ -154,4 +156,120 @@ case class AnalyzeColumnCommand( AttributeMap(attributePercentiles.toSeq) } + /** Returns true iff the we support gathering column statistics on column of the given type. */ + private def supportsType(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case BooleanType => true + case DateType => true + case TimestampType => true + case BinaryType | StringType => true + case _ => false + } + + /** Returns true iff the we support gathering histogram on column of the given type. */ + private def supportsHistogram(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case DateType => true + case TimestampType => true + case _ => false + } + + /** + * Constructs an expression to compute column statistics for a given column. + * + * The expression should create a single struct column with the following schema: + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, + * distinctCountsForIntervals: Array[Long] + * + * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and + * as a result should stay in sync with it. 
+ */ + private def statExprs( + col: Attribute, + conf: SQLConf, + colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { + def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => + expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } + }) + val one = Literal(1, LongType) + + // the approximate ndv (num distinct value) should never be larger than the number of rows + val numNonNulls = if (col.nullable) Count(col) else Count(one) + val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) + val numNulls = Subtract(Count(one), numNonNulls) + val defaultSize = Literal(col.dataType.defaultSize, LongType) + val nullArray = Literal(null, ArrayType(LongType)) + + def fixedLenTypeStruct: CreateNamedStruct = { + val genHistogram = + supportsHistogram(col.dataType) && colPercentiles.contains(col) + val intervalNdvsExpr = if (genHistogram) { + ApproxCountDistinctForIntervals(col, + Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) + } else { + nullArray + } + // For fixed width types, avg size should be the same as max size. + struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, + defaultSize, defaultSize, intervalNdvsExpr) + } + + col.dataType match { + case _: IntegralType => fixedLenTypeStruct + case _: DecimalType => fixedLenTypeStruct + case DoubleType | FloatType => fixedLenTypeStruct + case BooleanType => fixedLenTypeStruct + case DateType => fixedLenTypeStruct + case TimestampType => fixedLenTypeStruct + case BinaryType | StringType => + // For string and binary type, we don't compute min, max or histogram + val nullLit = Literal(null, col.dataType) + struct( + ndv, nullLit, nullLit, numNulls, + // Set avg/max size to default size if all the values are null or there is no value. + Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), + Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), + nullArray) + case _ => + throw new AnalysisException("Analyzing column statistics is not supported for column " + + s"${col.name} of data type: ${col.dataType}.") + } + } + + /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ + private def rowToColumnStat( + row: InternalRow, + attr: Attribute, + rowCount: Long, + percentiles: Option[ArrayData]): ColumnStat = { + // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. 
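// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): how the equi-height histogram
// assembled below is shaped. The percentile endpoints delimit ndvs.length
// intervals holding roughly the same number of non-null rows, and each bin
// records the distinct-value count observed in its interval. Bin,
// EquiHeightHistogram and buildHistogram are hypothetical stand-ins for
// Spark's HistogramBin / Histogram.
object HistogramSketch {
  final case class Bin(lo: Double, hi: Double, ndv: Long)
  final case class EquiHeightHistogram(rowsPerBin: Double, bins: Seq[Bin])

  def buildHistogram(
      endpoints: Array[Double],
      ndvs: Array[Long],
      nonNullRows: Long): EquiHeightHistogram = {
    require(endpoints.length == ndvs.length + 1, "need one more endpoint than bins")
    val bins = ndvs.zipWithIndex.map { case (ndv, i) =>
      Bin(endpoints(i), endpoints(i + 1), ndv)
    }
    EquiHeightHistogram(nonNullRows.toDouble / ndvs.length, bins)
  }

  def main(args: Array[String]): Unit = {
    val h = buildHistogram(Array(0.0, 25.0, 50.0, 75.0, 100.0), Array(20L, 18L, 22L, 19L), 1000L)
    println(h.rowsPerBin)   // 250.0 rows per bin
    println(h.bins.head)    // Bin(0.0,25.0,20)
  }
}
// ---------------------------------------------------------------------------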
+ val cs = ColumnStat( + distinctCount = Option(BigInt(row.getLong(0))), + // for string/binary min/max, get should return null + min = Option(row.get(1, attr.dataType)), + max = Option(row.get(2, attr.dataType)), + nullCount = Option(BigInt(row.getLong(3))), + avgLen = Option(row.getLong(4)), + maxLen = Option(row.getLong(5)) + ) + if (row.isNullAt(6) || cs.nullCount.isEmpty) { + cs + } else { + val ndvs = row.getArray(6).toLongArray() + assert(percentiles.get.numElements() == ndvs.length + 1) + val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) + // Construct equi-height histogram + val bins = ndvs.zipWithIndex.map { case (ndv, i) => + HistogramBin(endpoints(i), endpoints(i + 1), ndv) + } + val nonNullRows = rowCount - cs.nullCount.get + val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) + cs.copy(histogram = Some(histogram)) + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 1a0d67fc71fb..c27048626c8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -116,8 +116,8 @@ object CommandUtils extends Logging { oldStats: Option[CatalogStatistics], newTotalSize: BigInt, newRowCount: Option[BigInt]): Option[CatalogStatistics] = { - val oldTotalSize = oldStats.map(_.sizeInBytes.toLong).getOrElse(-1L) - val oldRowCount = oldStats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) + val oldTotalSize = oldStats.map(_.sizeInBytes).getOrElse(BigInt(-1)) + val oldRowCount = oldStats.flatMap(_.rowCount).getOrElse(BigInt(-1)) var newStats: Option[CatalogStatistics] = None if (newTotalSize >= 0 && newTotalSize != oldTotalSize) { newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index e56f8105fc9a..e11dbd201004 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command import org.apache.hadoop.conf.Configuration -import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} @@ -45,15 +44,7 @@ trait DataWritingCommand extends Command { // Output columns of the analyzed input query plan def outputColumns: Seq[Attribute] - lazy val metrics: Map[String, SQLMetric] = { - val sparkContext = SparkContext.getActive.get - Map( - "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), - "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numParts" -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") - ) - } + lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = { val serializableHadoopConf = new SerializableConfiguration(hadoopConf) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 7477d025dfe8..3c900be839aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -91,8 +91,8 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm if (sparkSession.conf.get(CATALOG_IMPLEMENTATION.key).equals("hive") && key.startsWith("hive.")) { logWarning(s"'SET $key=$value' might not work, since Spark doesn't support changing " + - "the Hive config dynamically. Please passing the Hive-specific config by adding the " + - s"prefix spark.hadoop (e.g., spark.hadoop.$key) when starting a Spark application. " + + "the Hive config dynamically. Please pass the Hive-specific config by adding the " + + s"prefix spark.hadoop (e.g. spark.hadoop.$key) when starting a Spark application. " + "For details, see the link: https://spark.apache.org/docs/latest/configuration.html#" + "dynamically-loading-spark-properties.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 306f43dc4214..f6ef433f2ce1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -21,7 +21,9 @@ import java.net.URI import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType @@ -136,12 +138,11 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo case class CreateDataSourceTableAsSelectCommand( table: CatalogTable, mode: SaveMode, - query: LogicalPlan) - extends RunnableCommand { - - override protected def innerChildren: Seq[LogicalPlan] = Seq(query) + query: LogicalPlan, + outputColumns: Seq[Attribute]) + extends DataWritingCommand { - override def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) assert(table.provider.isDefined) @@ -163,24 +164,25 @@ case class CreateDataSourceTableAsSelectCommand( } saveDataIntoTable( - sparkSession, table, table.storage.locationUri, query, SaveMode.Append, tableExists = true) + sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) - + sparkSession.sessionState.catalog.validateTableLocation(table) val tableLocation = if (table.tableType == CatalogTableType.MANAGED) { Some(sessionState.catalog.defaultTablePath(table.identifier)) } else { table.storage.locationUri } val result = saveDataIntoTable( - sparkSession, table, tableLocation, query, SaveMode.Overwrite, tableExists = false) + sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false) val newTable = table.copy( storage = table.storage.copy(locationUri = tableLocation), // We will use the schema of resolved.relation as the schema of the table 
(instead of // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). schema = result.schema) - sessionState.catalog.createTable(newTable, ignoreIfExists = false) + // Table location is already validated. No need to check it again during table creation. + sessionState.catalog.createTable(newTable, ignoreIfExists = false, validateLocation = false) result match { case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && @@ -198,10 +200,10 @@ case class CreateDataSourceTableAsSelectCommand( session: SparkSession, table: CatalogTable, tableLocation: Option[URI], - data: LogicalPlan, + physicalPlan: SparkPlan, mode: SaveMode, tableExists: Boolean): BaseRelation = { - // Create the relation based on the input logical plan: `data`. + // Create the relation based on the input logical plan: `query`. val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_)) val dataSource = DataSource( session, @@ -212,7 +214,7 @@ case class CreateDataSourceTableAsSelectCommand( catalogTable = if (tableExists) Some(table) else None) try { - dataSource.writeAndRead(mode, query) + dataSource.writeAndRead(mode, query, outputColumns, physicalPlan) } catch { case ex: AnalysisException => logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 0142f17ce62e..bf4d96fa18d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -314,8 +314,8 @@ case class AlterTableChangeColumnCommand( val resolver = sparkSession.sessionState.conf.resolver DDLUtils.verifyAlterTableType(catalog, table, isView = false) - // Find the origin column from schema by column name. - val originColumn = findColumnByName(table.schema, columnName, resolver) + // Find the origin column from dataSchema by column name. + val originColumn = findColumnByName(table.dataSchema, columnName, resolver) // Throw an AnalysisException if the column name/dataType is changed. if (!columnEqual(originColumn, newColumn, resolver)) { throw new AnalysisException( @@ -324,7 +324,7 @@ case class AlterTableChangeColumnCommand( s"'${newColumn.name}' with type '${newColumn.dataType}'") } - val newSchema = table.schema.fields.map { field => + val newDataSchema = table.dataSchema.fields.map { field => if (field.name == originColumn.name) { // Create a new column from the origin column with the new comment. addComment(field, newColumn.getComment) @@ -332,8 +332,7 @@ case class AlterTableChangeColumnCommand( field } } - val newTable = table.copy(schema = StructType(newSchema)) - catalog.alterTable(newTable) + catalog.alterTableDataSchema(tableName, StructType(newDataSchema)) Seq.empty[Row] } @@ -345,7 +344,8 @@ case class AlterTableChangeColumnCommand( schema.fields.collectFirst { case field if resolver(field.name, name) => field }.getOrElse(throw new AnalysisException( - s"Invalid column reference '$name', table schema is '${schema}'")) + s"Can't find column `$name` given table data columns " + + s"${schema.fieldNames.mkString("[`", "`, `", "`]")}")) } // Add the comment to a column, if comment is empty, return the original column. 
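// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the column lookup above now
// resolves the target against the table's *data* schema (partition columns are
// excluded) using the session resolver, and the new error message lists the
// available data columns. A plain-Scala model; Field and findColumnByName are
// stand-ins for StructField and the private helper, and
// IllegalArgumentException stands in for AnalysisException.
object ColumnLookup {
  final case class Field(name: String)

  def findColumnByName(
      dataSchema: Seq[Field],
      name: String,
      resolver: (String, String) => Boolean): Field =
    dataSchema.find(f => resolver(f.name, name)).getOrElse {
      throw new IllegalArgumentException(
        s"Can't find column `$name` given table data columns " +
          dataSchema.map(_.name).mkString("[`", "`, `", "`]"))
    }

  def main(args: Array[String]): Unit = {
    val caseInsensitive = (a: String, b: String) => a.equalsIgnoreCase(b)
    val schema = Seq(Field("id"), Field("price"))
    println(findColumnByName(schema, "PRICE", caseInsensitive))   // Field(price)
    // findColumnByName(schema, "qty", caseInsensitive)
    //   => Can't find column `qty` given table data columns [`id`, `price`]
  }
}
// ---------------------------------------------------------------------------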
@@ -610,10 +610,10 @@ case class AlterTableRecoverPartitionsCommand( val root = new Path(table.location) logInfo(s"Recover all the partitions in $root") - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val hadoopConf = spark.sessionState.newHadoopConf() + val fs = root.getFileSystem(hadoopConf) val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt - val hadoopConf = spark.sparkContext.hadoopConfiguration val pathFilter = getPathFilter(hadoopConf) val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) @@ -697,7 +697,7 @@ case class AlterTableRecoverPartitionsCommand( pathFilter: PathFilter, threshold: Int): GenMap[String, PartitionStatistics] = { if (partitionSpecsAndLocs.length > threshold) { - val hadoopConf = spark.sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() val serializableConfiguration = new SerializableConfiguration(hadoopConf) val serializedPaths = partitionSpecsAndLocs.map(_._2.toString).toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 4f92ffee687a..1f7808c2f8e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -40,6 +40,10 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType} * CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] [databaseName.]functionName * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] * }}} + * + * @param ignoreIfExists: When true, ignore if the function with the specified name exists + * in the specified database. + * @param replace: When true, alter the function with the specified name */ case class CreateFunctionCommand( databaseName: Option[String], @@ -47,17 +51,17 @@ case class CreateFunctionCommand( className: String, resources: Seq[FunctionResource], isTemp: Boolean, - ifNotExists: Boolean, + ignoreIfExists: Boolean, replace: Boolean) extends RunnableCommand { - if (ifNotExists && replace) { + if (ignoreIfExists && replace) { throw new AnalysisException("CREATE FUNCTION with both IF NOT EXISTS and REPLACE" + " is not allowed.") } // Disallow to define a temporary function with `IF NOT EXISTS` - if (ifNotExists && isTemp) { + if (ignoreIfExists && isTemp) { throw new AnalysisException( "It is not allowed to define a TEMPORARY function with IF NOT EXISTS.") } @@ -79,12 +83,12 @@ case class CreateFunctionCommand( // Handles `CREATE OR REPLACE FUNCTION AS ... USING ...` if (replace && catalog.functionExists(func.identifier)) { // alter the function in the metastore - catalog.alterFunction(CatalogFunction(func.identifier, className, resources)) + catalog.alterFunction(func) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. // We do not load it into FunctionRegistry right now. 
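// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the CreateFunctionCommand flag
// handling above, reduced to its decision table. IF NOT EXISTS and REPLACE are
// mutually exclusive, a TEMPORARY function rejects IF NOT EXISTS, and for a
// permanent function an existing entry is altered only under REPLACE.
// FunctionAction and planFunctionCommand are hypothetical names.
object FunctionCommandPlan {
  sealed trait FunctionAction
  case object RegisterTemporary extends FunctionAction
  case object AlterExisting extends FunctionAction
  final case class CreateNew(ignoreIfExists: Boolean) extends FunctionAction

  def planFunctionCommand(
      isTemp: Boolean,
      ignoreIfExists: Boolean,
      replace: Boolean,
      alreadyExists: Boolean): FunctionAction = {
    require(!(ignoreIfExists && replace),
      "CREATE FUNCTION with both IF NOT EXISTS and REPLACE is not allowed.")
    require(!(ignoreIfExists && isTemp),
      "It is not allowed to define a TEMPORARY function with IF NOT EXISTS.")
    if (isTemp) {
      RegisterTemporary
    } else if (replace && alreadyExists) {
      AlterExisting
    } else {
      CreateNew(ignoreIfExists)
    }
  }

  def main(args: Array[String]): Unit = {
    println(planFunctionCommand(isTemp = false, ignoreIfExists = false, replace = true, alreadyExists = true))
    // AlterExisting
  }
}
// ---------------------------------------------------------------------------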
- catalog.createFunction(CatalogFunction(func.identifier, className, resources), ifNotExists) + catalog.createFunction(func, ignoreIfExists) } } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index e400975f1970..44749190c79e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -695,10 +695,11 @@ case class DescribeColumnCommand( // Show column stats when EXTENDED or FORMATTED is specified. buffer += Row("min", cs.flatMap(_.min.map(_.toString)).getOrElse("NULL")) buffer += Row("max", cs.flatMap(_.max.map(_.toString)).getOrElse("NULL")) - buffer += Row("num_nulls", cs.map(_.nullCount.toString).getOrElse("NULL")) - buffer += Row("distinct_count", cs.map(_.distinctCount.toString).getOrElse("NULL")) - buffer += Row("avg_col_len", cs.map(_.avgLen.toString).getOrElse("NULL")) - buffer += Row("max_col_len", cs.map(_.maxLen.toString).getOrElse("NULL")) + buffer += Row("num_nulls", cs.flatMap(_.nullCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("distinct_count", + cs.flatMap(_.distinctCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("avg_col_len", cs.flatMap(_.avgLen.map(_.toString)).getOrElse("NULL")) + buffer += Row("max_col_len", cs.flatMap(_.maxLen.map(_.toString)).getOrElse("NULL")) val histDesc = for { c <- cs hist <- c.histogram diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 9dbbe9946ee9..69c03d862391 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -153,12 +153,29 @@ class BasicWriteJobStatsTracker( totalNumOutput += summary.numRows } - metrics("numFiles").add(numFiles) - metrics("numOutputBytes").add(totalNumBytes) - metrics("numOutputRows").add(totalNumOutput) - metrics("numParts").add(numPartitions) + metrics(BasicWriteJobStatsTracker.NUM_FILES_KEY).add(numFiles) + metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_BYTES_KEY).add(totalNumBytes) + metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_ROWS_KEY).add(totalNumOutput) + metrics(BasicWriteJobStatsTracker.NUM_PARTS_KEY).add(numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toList) } } + +object BasicWriteJobStatsTracker { + private val NUM_FILES_KEY = "numFiles" + private val NUM_OUTPUT_BYTES_KEY = "numOutputBytes" + private val NUM_OUTPUT_ROWS_KEY = "numOutputRows" + private val NUM_PARTS_KEY = "numParts" + + def metrics: Map[String, SQLMetric] = { + val sparkContext = SparkContext.getActive.get + Map( + NUM_FILES_KEY -> SQLMetrics.createMetric(sparkContext, "number of written files"), + NUM_OUTPUT_BYTES_KEY -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), + NUM_OUTPUT_ROWS_KEY -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + NUM_PARTS_KEY -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") + ) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 4046396d0e61..a66a07673e25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -85,7 +85,7 @@ class CatalogFileIndex( sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec, Option(timeNs)) } else { new InMemoryFileIndex( - sparkSession, rootPaths, table.storage.properties, partitionSchema = None) + sparkSession, rootPaths, table.storage.properties, userSpecifiedSchema = None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index 54549f698aca..c0df6c779d7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -45,11 +45,11 @@ object CodecStreams { } /** - * Creates an input stream from the string path and add a closure for the input stream to be + * Creates an input stream from the given path and add a closure for the input stream to be * closed on task completion. */ - def createInputStreamWithCloseResource(config: Configuration, path: String): InputStream = { - val inputStream = createInputStream(config, new Path(path)) + def createInputStreamWithCloseResource(config: Configuration, path: Path): InputStream = { + val inputStream = createInputStream(config, path) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) inputStream } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 25e121050427..f16d824201e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil @@ -31,18 +30,21 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import 
org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.sql.types.{CalendarIntervalType, StructField, StructType} import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils @@ -100,24 +102,6 @@ case class DataSource( bucket.sortColumnNames, "in the sort definition", equality) } - /** - * In the read path, only managed tables by Hive provide the partition columns properly when - * initializing this class. All other file based data sources will try to infer the partitioning, - * and then cast the inferred types to user specified dataTypes if the partition columns exist - * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or - * inconsistent data types as reported in SPARK-21463. - * @param fileIndex A FileIndex that will perform partition inference - * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` - */ - private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = { - val resolved = fileIndex.partitionSchema.map { partitionField => - // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred - userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( - partitionField) - } - StructType(resolved) - } - /** * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer * it. In the read path, only managed tables by Hive provide the partition columns properly when @@ -137,31 +121,26 @@ case class DataSource( * be any further inference in any triggers. * * @param format the file format object for this DataSource - * @param fileStatusCache the shared cache for file statuses to speed up listing + * @param fileIndex optional [[InMemoryFileIndex]] for getting partition schema and file list * @return A pair of the data schema (excluding partition columns) and the schema of the partition * columns. 
*/ private def getOrInferFileFormatSchema( format: FileFormat, - fileStatusCache: FileStatusCache = NoopCache): (StructType, StructType) = { - // the operations below are expensive therefore try not to do them if we don't need to, e.g., + fileIndex: Option[InMemoryFileIndex] = None): (StructType, StructType) = { + // The operations below are expensive therefore try not to do them if we don't need to, e.g., // in streaming mode, we have already inferred and registered partition columns, we will // never have to materialize the lazy val below - lazy val tempFileIndex = { - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.toSeq.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - }.toArray - new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache) + lazy val tempFileIndex = fileIndex.getOrElse { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false) + createInMemoryFileIndex(globbedPaths) } + val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource - combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex) + tempFileIndex.partitionSchema } else { // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred // partitioning @@ -353,13 +332,7 @@ case class DataSource( caseInsensitiveOptions.get("path").toSeq ++ paths, sparkSession.sessionState.newHadoopConf()) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) - val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None) - val fileCatalog = if (userSpecifiedSchema.nonEmpty) { - val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog) - new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema)) - } else { - tempFileCatalog - } + val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, userSpecifiedSchema) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sparkSession, @@ -381,24 +354,23 @@ case class DataSource( // This is a non-streaming file based datasource. 
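As a rough illustration of the partition-schema handling described in the comments above, the sketch below (path, data, and column names are hypothetical, not from the patch) writes a partitioned directory and reads it back with a user-specified schema, so the inferred partition values are cast to the requested type rather than the inferred one (the SPARK-18510 behaviour).

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object UserSpecifiedPartitionTypeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("user-specified-partition-type")
      .getOrCreate()
    import spark.implicits._

    val base = "/tmp/partition-type-demo"   // hypothetical path
    Seq((1, "a"), (2, "b")).toDF("id", "value")
      .withColumn("day", lit(20180401))
      .write.mode("overwrite").partitionBy("day").parquet(base)

    // Partition inference alone would type `day` as an integer; supplying a schema makes
    // the inferred partition values get cast to the user-specified type (here, string).
    val schema = new StructType()
      .add("id", IntegerType)
      .add("value", StringType)
      .add("day", StringType)

    spark.read.schema(schema).parquet(base).printSchema()
    spark.stop()
  }
}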
case (format: FileFormat, _) => - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.flatMap( - DataSource.checkAndGlobPathIfNecessary(hadoopConf, _, checkFilesExist)).toArray - - val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) - val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache) - - val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && - catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = true, checkFilesExist = checkFilesExist) + val useCatalogFileIndex = sparkSession.sqlContext.conf.manageFilesourcePartitions && + catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog && + catalogTable.get.partitionColumnNames.nonEmpty + val (fileCatalog, dataSchema, partitionSchema) = if (useCatalogFileIndex) { val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes - new CatalogFileIndex( + val index = new CatalogFileIndex( sparkSession, catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize)) + (index, catalogTable.get.dataSchema, catalogTable.get.partitionSchema) } else { - new InMemoryFileIndex( - sparkSession, globbedPaths, options, Some(partitionSchema), fileStatusCache) + val index = createInMemoryFileIndex(globbedPaths) + val (resultDataSchema, resultPartitionSchema) = + getOrInferFileFormatSchema(format, Some(index)) + (index, resultDataSchema, resultPartitionSchema) } HadoopFsRelation( @@ -435,10 +407,11 @@ case class DataSource( } /** - * Writes the given [[LogicalPlan]] out in this [[FileFormat]]. + * Creates a command node to write the given [[LogicalPlan]] out to the given [[FileFormat]]. + * The returned command is unresolved and need to be analyzed. */ private def planForWritingFileFormat( - format: FileFormat, mode: SaveMode, data: LogicalPlan): LogicalPlan = { + format: FileFormat, mode: SaveMode, data: LogicalPlan): InsertIntoHadoopFsRelationCommand = { // Don't glob path for the write path. The contracts here are: // 1. Only one output path can be specified on the write path; // 2. Output path must be a legal HDFS style file system path; @@ -482,9 +455,24 @@ case class DataSource( /** * Writes the given [[LogicalPlan]] out to this [[DataSource]] and returns a [[BaseRelation]] for * the following reading. + * + * @param mode The save mode for this writing. + * @param data The input query plan that produces the data to be written. Note that this plan + * is analyzed and optimized. + * @param outputColumns The original output columns of the input query plan. The optimizer may not + * preserve the output column's names' case, so we need this parameter + * instead of `data.output`. + * @param physicalPlan The physical plan of the input query plan. We should run the writing + * command with this physical plan instead of creating a new physical plan, + * so that the metrics can be correctly linked to the given physical plan and + * shown in the web UI. 
*/ - def writeAndRead(mode: SaveMode, data: LogicalPlan): BaseRelation = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { + def writeAndRead( + mode: SaveMode, + data: LogicalPlan, + outputColumns: Seq[Attribute], + physicalPlan: SparkPlan): BaseRelation = { + if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } @@ -493,9 +481,23 @@ case class DataSource( dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) case format: FileFormat => - sparkSession.sessionState.executePlan(planForWritingFileFormat(format, mode, data)).toRdd + val cmd = planForWritingFileFormat(format, mode, data) + val resolvedPartCols = cmd.partitionColumns.map { col => + // The partition columns created in `planForWritingFileFormat` should always be + // `UnresolvedAttribute` with a single name part. + assert(col.isInstanceOf[UnresolvedAttribute]) + val unresolved = col.asInstanceOf[UnresolvedAttribute] + assert(unresolved.nameParts.length == 1) + val name = unresolved.nameParts.head + outputColumns.find(a => equality(a.name, name)).getOrElse { + throw new AnalysisException( + s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") + } + } + val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns) + resolved.run(sparkSession, physicalPlan) // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring - copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() + copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation() case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } @@ -513,11 +515,46 @@ case class DataSource( case dataSource: CreatableRelationProvider => SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => + DataSource.validateSchema(data.schema) planForWritingFileFormat(format, mode, data) case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } } + + /** Returns an [[InMemoryFileIndex]] that can be used to get partition schema and file list. */ + private def createInMemoryFileIndex(globbedPaths: Seq[Path]): InMemoryFileIndex = { + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + new InMemoryFileIndex( + sparkSession, globbedPaths, options, userSpecifiedSchema, fileStatusCache) + } + + /** + * Checks and returns files in all the paths. 
+ */ + private def checkAndGlobPathIfNecessary( + checkEmptyGlobPath: Boolean, + checkFilesExist: Boolean): Seq[Path] = { + val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() + allPaths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) + + if (checkEmptyGlobPath && globPath.isEmpty) { + throw new AnalysisException(s"Path does not exist: $qualified") + } + + // Sufficient to check head of the globPath seq for non-glob scenario + // Don't need to check once again if files exist in streaming mode + if (checkFilesExist && !fs.exists(globPath.head)) { + throw new AnalysisException(s"Path does not exist: ${globPath.head}") + } + globPath + }.toSeq + } } object DataSource extends Logging { @@ -531,6 +568,8 @@ object DataSource extends Logging { val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat" val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName + val socket = classOf[TextSocketSourceProvider].getCanonicalName + val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -551,7 +590,9 @@ object DataSource extends Logging { "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc, "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, - "com.databricks.spark.csv" -> csv + "com.databricks.spark.csv" -> csv, + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, + "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate ) } @@ -662,26 +703,25 @@ object DataSource extends Logging { } /** - * If `path` is a file pattern, return all the files that match it. Otherwise, return itself. - * If `checkFilesExist` is `true`, also check the file existence. + * Called before writing into a FileFormat based data source to make sure the + * supplied schema is not empty. + * @param schema */ - private def checkAndGlobPathIfNecessary( - hadoopConf: Configuration, - path: String, - checkFilesExist: Boolean): Seq[Path] = { - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - - if (globPath.isEmpty) { - throw new AnalysisException(s"Path does not exist: $qualified") + private def validateSchema(schema: StructType): Unit = { + def hasEmptySchema(schema: StructType): Boolean = { + schema.size == 0 || schema.find { + case StructField(_, b: StructType, _, _) => hasEmptySchema(b) + case _ => false + }.isDefined } - // Sufficient to check head of the globPath seq for non-glob scenario - // Don't need to check once again if files exist in streaming mode - if (checkFilesExist && !fs.exists(globPath.head)) { - throw new AnalysisException(s"Path does not exist: ${globPath.head}") + + + if (hasEmptySchema(schema)) { + throw new AnalysisException( + s""" + |Datasource does not support writing empty or nested empty schemas. + |Please make sure the data schema has at least one or more column(s). 
+ """.stripMargin) } - globPath } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d94c5bbccdd8..3f41612c0806 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast case CreateTable(tableDesc, mode, Some(query)) if query.resolved && DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema)) - CreateDataSourceTableAsSelectCommand(tableDesc, mode, query) + CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output) case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _), parts, query, overwrite, false) if parts.isEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 1d80a69bc5a1..401597f96721 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -190,9 +190,18 @@ object FileFormatWriter extends Logging { global = false, child = plan).execute() } - val ret = new Array[WriteTaskResult](rdd.partitions.length) + + // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single + // partition rdd to make sure we at least set up one write task to write the metadata. 
+ val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) { + sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1) + } else { + rdd + } + + val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length) sparkSession.sparkContext.runJob( - rdd, + rddWithNonEmptyPartitions, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -202,7 +211,7 @@ object FileFormatWriter extends Logging { committer, iterator = iter) }, - 0 until rdd.partitions.length, + rddWithNonEmptyPartitions.partitions.indices, (index, res: WriteTaskResult) => { committer.onTaskCommit(res.commitMsg) ret(index) = res diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 8731ee88f87f..28c36b6020d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -21,12 +21,15 @@ import java.io.{FileNotFoundException, IOException} import scala.collection.mutable +import org.apache.parquet.io.ParquetDecodingException + import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator /** @@ -179,7 +182,23 @@ class FileScanRDD( currentIterator = readCurrentFile() } - hasNext + try { + hasNext + } catch { + case e: SchemaColumnConvertNotSupportedException => + val message = "Parquet column cannot be converted in " + + s"file ${currentFile.filePath}. Column: ${e.getColumn}, " + + s"Expected: ${e.getLogicalType}, Found: ${e.getPhysicalType}" + throw new QueryExecutionException(message, e) + case e: ParquetDecodingException => + if (e.getMessage.contains("Can not read value at")) { + val message = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the " + + "corresponding files. 
Details: " + throw new QueryExecutionException(message, e) + } + throw e + } } else { currentFile = null InputFileBlockHolder.unset() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 16b22717b8d9..0a568d6b8adc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -76,7 +76,10 @@ object FileSourceStrategy extends Strategy with Logging { fsRelation.partitionSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) + logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") val dataColumns = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 83cf26c63a17..00a78f7343c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -30,9 +30,22 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl /** * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines * in that file. + * + * @param file A part (i.e. "block") of a single file that should be read line by line. + * @param lineSeparator A line separator that should be used for each line. If the value is `None`, + * it covers `\r`, `\r\n` and `\n`. + * @param conf Hadoop configuration + * + * @note The behavior when `lineSeparator` is `None` (covering `\r`, `\r\n` and `\n`) is defined + * by [[LineRecordReader]], not within Spark. */ class HadoopFileLinesReader( - file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable { + file: PartitionedFile, + lineSeparator: Option[Array[Byte]], + conf: Configuration) extends Iterator[Text] with Closeable { + + def this(file: PartitionedFile, conf: Configuration) = this(file, None, conf) + private val iterator = { val fileSplit = new FileSplit( new Path(new URI(file.filePath)), @@ -42,7 +55,13 @@ class HadoopFileLinesReader( Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - val reader = new LineRecordReader() + + val reader = lineSeparator match { + case Some(sep) => new LineRecordReader(sep) + // If the line separator is `None`, it covers `\r`, `\r\n` and `\n`. 
+ case _ => new LineRecordReader() + } + reader.initialize(fileSplit, hadoopAttemptContext) new RecordReaderIterator(reader) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 89d8a85a9cbd..b2f73b7f8d1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -67,6 +67,9 @@ case class HadoopFsRelation( } } + // When data and partition schemas have overlapping columns, the output + // schema respects the order of the data schema for the overlapping columns, and it + // respects the data types of the partition schema. val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) @@ -82,7 +85,11 @@ case class HadoopFsRelation( } } - override def sizeInBytes: Long = location.sizeInBytes + override def sizeInBytes: Long = { + val compressionFactor = sqlContext.conf.fileCompressionFactor + (location.sizeInBytes * compressionFactor).toLong + } + override def inputFiles: Array[String] = location.inputFiles } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 318ada0ceefc..739d1f456e3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -41,17 +41,17 @@ import org.apache.spark.util.SerializableConfiguration * @param rootPathsSpecified the list of root table paths to scan (some of which might be * filtered out later) * @param parameters as set of options to control discovery - * @param partitionSchema an optional partition schema that will be use to provide types for the - * discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ class InMemoryFileIndex( sparkSession: SparkSession, rootPathsSpecified: Seq[Path], parameters: Map[String, String], - partitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends PartitioningAwareFileIndex( - sparkSession, parameters, partitionSchema, fileStatusCache) { + sparkSession, parameters, userSpecifiedSchema, fileStatusCache) { // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir) // or "/.../_spark_metadata/0" (a file in the metadata dir). 
`rootPathsSpecified` might contain diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index ad24e280d942..dd7ef0d15c14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.util.SchemaUtils /** @@ -89,13 +90,19 @@ case class InsertIntoHadoopFsRelationCommand( } val pathExists = fs.exists(qualifiedOutputPath) - // If we are appending data to an existing dir. - val isAppend = pathExists && (mode == SaveMode.Append) + + val enableDynamicOverwrite = + sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + // This config only makes sense when we are overwriting a partitioned dataset with dynamic + // partition columns. + val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && + staticPartitions.size < partitionColumns.length val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, - outputPath = outputPath.toString) + outputPath = outputPath.toString, + dynamicPartitionOverwrite = dynamicPartitionOverwrite) val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => @@ -103,6 +110,9 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.Overwrite, true) => if (ifPartitionNotExists && matchingPartitions.nonEmpty) { false + } else if (dynamicPartitionOverwrite) { + // For dynamic partition overwrite, do not delete partition directories ahead. + true } else { deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) true @@ -126,7 +136,9 @@ case class InsertIntoHadoopFsRelationCommand( catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), ifNotExists = true).run(sparkSession) } - if (mode == SaveMode.Overwrite) { + // For dynamic partition overwrite, we never remove partitions but only update existing + // ones. + if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) { val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions if (deletedPartitions.nonEmpty) { AlterTableDropPartitionCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 6b6f6388d54e..cc8af7b92c45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -34,13 +34,13 @@ import org.apache.spark.sql.types.{StringType, StructType} * It provides the necessary methods to parse partition data based on a set of files. 
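For context on the dynamic partition overwrite logic added to InsertIntoHadoopFsRelationCommand above, here is a hedged sketch (hypothetical path and data, not from the patch) that switches spark.sql.sources.partitionOverwriteMode to dynamic, so an overwrite with no static partition values replaces only the partitions the written data touches.

import org.apache.spark.sql.SparkSession

object DynamicPartitionOverwriteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("dynamic-partition-overwrite")
      .getOrCreate()
    import spark.implicits._

    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

    val path = "/tmp/dynamic-overwrite-demo"   // hypothetical path
    Seq((1, "2018-04-01"), (2, "2018-04-02")).toDF("id", "day")
      .write.partitionBy("day").mode("overwrite").parquet(path)

    // Rewrites only day=2018-04-02; day=2018-04-01 survives because
    // dynamicPartitionOverwrite is true and no partition directories are deleted upfront.
    Seq((3, "2018-04-02")).toDF("id", "day")
      .write.partitionBy("day").mode("overwrite").parquet(path)

    spark.read.parquet(path).orderBy("id").show()
    spark.stop()
  }
}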
* * @param parameters as set of options to control partition discovery - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ abstract class PartitioningAwareFileIndex( sparkSession: SparkSession, parameters: Map[String, String], - userPartitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends FileIndex with Logging { import PartitioningAwareFileIndex.BASE_PATH_PARAM @@ -126,35 +126,32 @@ abstract class PartitioningAwareFileIndex( val caseInsensitiveOptions = CaseInsensitiveMap(parameters) val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) - - userPartitionSchema match { + val inferredPartitionSpec = PartitioningUtils.parsePartitions( + leafDirs, + typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, + basePaths = basePaths, + timeZoneId = timeZoneId) + userSpecifiedSchema match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => - val spec = PartitioningUtils.parsePartitions( - leafDirs, - typeInference = false, - basePaths = basePaths, - timeZoneId = timeZoneId) + val userPartitionSchema = + combineInferredAndUserSpecifiedPartitionSchema(inferredPartitionSpec) - // Without auto inference, all of value in the `row` should be null or in StringType, // we need to cast into the data type that user specified. def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => + val dt = inferredPartitionSpec.partitionColumns.fields(i).dataType Cast( - Literal.create(row.getUTF8String(i), StringType), - userProvidedSchema.fields(i).dataType, + Literal.create(row.get(i, dt), dt), + userPartitionSchema.fields(i).dataType, Option(timeZoneId)).eval() }: _*) } - PartitionSpec(userProvidedSchema, spec.partitions.map { part => + PartitionSpec(userPartitionSchema, inferredPartitionSpec.partitions.map { part => part.copy(values = castPartitionValuesToUserSchema(part.values)) }) case _ => - PartitioningUtils.parsePartitions( - leafDirs, - typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, - basePaths = basePaths, - timeZoneId = timeZoneId) + inferredPartitionSpec } } @@ -236,6 +233,25 @@ abstract class PartitioningAwareFileIndex( val name = path.getName !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } + + /** + * In the read path, only managed tables by Hive provide the partition columns properly when + * initializing this class. All other file based data sources will try to infer the partitioning, + * and then cast the inferred types to user specified dataTypes if the partition columns exist + * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or + * inconsistent data types as reported in SPARK-21463. 
+ * @param spec A partition inference result + * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` + */ + private def combineInferredAndUserSpecifiedPartitionSchema(spec: PartitionSpec): StructType = { + val equality = sparkSession.sessionState.conf.resolver + val resolved = spec.partitionColumns.map { partitionField => + // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred + userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( + partitionField) + } + StructType(resolved) + } } object PartitioningAwareFileIndex { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 472bf82d3604..f9a24806953e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -407,6 +407,34 @@ object PartitioningUtils { Literal(bigDecimal) } + val dateTry = Try { + // try and parse the date, if no exception occurs this is a candidate to be resolved as + // DateType + DateTimeUtils.getThreadLocalDateFormat.parse(raw) + // SPARK-23436: Casting the string to date may still return null if a bad Date is provided. + // This can happen since DateFormat.parse may not use the entire text of the given string: + // so if there are extra-characters after the date, it returns correctly. + // We need to check that we can cast the raw string since we later can use Cast to get + // the partition values with the right DataType (see + // org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.inferPartitioning) + val dateValue = Cast(Literal(raw), DateType).eval() + // Disallow DateType if the cast returned null + require(dateValue != null) + Literal.create(dateValue, DateType) + } + + val timestampTry = Try { + val unescapedRaw = unescapePathName(raw) + // try and parse the date, if no exception occurs this is a candidate to be resolved as + // TimestampType + DateTimeUtils.getThreadLocalTimestampFormat(timeZone).parse(unescapedRaw) + // SPARK-23436: see comment for date + val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(timeZone.getID)).eval() + // Disallow TimestampType if the cast returned null + require(timestampValue != null) + Literal.create(timestampValue, TimestampType) + } + if (typeInference) { // First tries integral types Try(Literal.create(Integer.parseInt(raw), IntegerType)) @@ -415,16 +443,8 @@ object PartitioningUtils { // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) // Then falls back to date/timestamp types - .orElse(Try( - Literal.create( - DateTimeUtils.getThreadLocalTimestampFormat(timeZone) - .parse(unescapePathName(raw)).getTime * 1000L, - TimestampType))) - .orElse(Try( - Literal.create( - DateTimeUtils.millisToDays( - DateTimeUtils.getThreadLocalDateFormat.parse(raw).getTime), - DateType))) + .orElse(timestampTry) + .orElse(dateTry) // Then falls back to string .getOrElse { if (raw == DEFAULT_PARTITION_NAME) { @@ -466,7 +486,8 @@ object PartitioningUtils { val equality = columnNameEquality(caseSensitive) StructType(partitionColumns.map { col => schema.find(f => equality(f.name, col)).getOrElse { - throw new AnalysisException(s"Partition column $col not found in schema $schema") + val schemaCatalog = 
schema.catalogString + throw new AnalysisException(s"Partition column `$col` not found in schema $schemaCatalog") } }).asNullable } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 3b830accb83f..16b2367bfdd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -55,7 +55,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { partitionSchema, sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) if (partitionKeyFilters.nonEmpty) { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala index 40825a1f724b..39c594a9bc61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala @@ -29,11 +29,15 @@ import org.apache.spark.sql.internal.SQLConf * A variant of [[HadoopMapReduceCommitProtocol]] that allows specifying the actual * Hadoop output committer using an option specified in SQLConf. 
*/ -class SQLHadoopMapReduceCommitProtocol(jobId: String, path: String) - extends HadoopMapReduceCommitProtocol(jobId, path) with Serializable with Logging { +class SQLHadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) + extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) + with Serializable with Logging { override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { - var committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) + var committer = super.setupCommitter(context) val configuration = context.getConfiguration val clazz = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 2031381dd2e1..bc1f4ab3bb05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.datasources.csv +import java.net.URI import java.nio.charset.{Charset, StandardCharsets} import com.univocity.parsers.csv.CsvParser import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.Job @@ -32,7 +33,6 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -161,7 +161,8 @@ object TextInputCSVDataSource extends CSVDataSource { val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => + val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) + val tokenRDD = sampled.rdd.mapPartitions { iter => val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) @@ -206,7 +207,7 @@ object MultiLineCSVDataSource extends CSVDataSource { parser: UnivocityParser, schema: StructType): Iterator[InternalRow] = { UnivocityParser.parseStream( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), parser.options.headerFlag, parser, schema) @@ -218,8 +219,9 @@ object MultiLineCSVDataSource extends CSVDataSource { parsedOptions: CSVOptions): StructType = { val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions) csv.flatMap { lines => + val path = new Path(lines.getPath()) UnivocityParser.tokenizeStream( - CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, path), shouldDropHeader = false, new 
CsvParser(parsedOptions.asParserSettings)) }.take(1).headOption match { @@ -230,11 +232,12 @@ object MultiLineCSVDataSource extends CSVDataSource { UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource( lines.getConfiguration, - lines.getPath()), + new Path(lines.getPath())), parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) + val sampled = CSVUtils.sample(tokenRDD, parsedOptions) + CSVInferSchema.infer(sampled, header, parsedOptions) case None => // If the first row could not be read, just return the empty schema. StructType(Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index c16790630ce1..2ec0fc605a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -150,6 +150,9 @@ class CSVOptions( val isCommentSet = this.comment != '\u0000' + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 72b053d2092c..31464f1bcc68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.datasources.csv +import org.apache.spark.input.PortableDataStream +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -131,4 +134,29 @@ object CSVUtils { schema.foreach(field => verifyType(field.dataType)) } + /** + * Sample CSV dataset as configured by `samplingRatio`. + */ + def sample(csv: Dataset[String], options: CSVOptions): Dataset[String] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } + + /** + * Sample CSV RDD as configured by `samplingRatio`. 
+ */ + def sample(csv: RDD[Array[String]], options: CSVOptions): RDD[Array[String]] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 7d6d7e7eef92..3d6cc30f2ba8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -203,6 +203,8 @@ class UnivocityParser( case _: BadRecordException => None } } + // For records with less or more tokens than the schema, tries to return partial results + // if possible. throw BadRecordException( () => getCurrentInput, () => getPartialResult(), @@ -218,6 +220,9 @@ class UnivocityParser( row } catch { case NonFatal(e) => + // For corrupted records with the number of tokens same as the schema, + // CSV reader doesn't support partial results. All fields other than the field + // configured by `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => getCurrentInput, () => None, e) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 7a6c0f9fed2f..1723596de1db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -32,6 +32,13 @@ import org.apache.spark.util.Utils */ object DriverRegistry extends Logging { + /** + * Load DriverManager first to avoid any race condition between + * DriverManager static initialization block and specific driver class's + * static initialization block. e.g. 
PhoenixDriver + */ + DriverManager.getDrivers + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty def register(className: String): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 8b7c2709afde..983a5f0dcade 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.execution.datasources.json import java.io.InputStream +import java.net.URI import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.FileInputFormat @@ -30,7 +31,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} -import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.datasources._ @@ -91,25 +92,36 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths) + val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + inferFromDataset(json, parsedOptions) } def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) - val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0)) - JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String) + val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd + val rowParser = parsedOptions.encoding.map { enc => + CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) + }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) + + JsonInferSchema.infer(rdd, parsedOptions, rowParser) } private def createBaseDataset( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): Dataset[String] = { + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): Dataset[String] = { val paths = inputPaths.map(_.getPath.toString) + val textOptions = Map.empty[String, String] ++ + parsedOptions.encoding.map("encoding" -> _) ++ + parsedOptions.lineSeparator.map("lineSep" -> _) + sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = textOptions ).resolveRelation(checkFilesExist = false)) .select("value").as(Encoders.STRING) } @@ -119,10 +131,14 @@ object TextInputJsonDataSource extends JsonDataSource { file: PartitionedFile, parser: JacksonParser, schema: StructType): Iterator[InternalRow] = { - val linesReader = 
new HadoopFileLinesReader(file, conf) + val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + val textParser = parser.options.encoding + .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text)) + .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text)) + val safeParser = new FailureSafeParser[Text]( - input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + input => parser.parse(input, textParser, textToUTF8String), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) @@ -145,7 +161,11 @@ object MultiLineJsonDataSource extends JsonDataSource { parsedOptions: JSONOptions): StructType = { val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) - JsonInferSchema.infer(sampled, parsedOptions, createParser) + val parser = parsedOptions.encoding + .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) + .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) + + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) } private def createBaseRdd( @@ -167,10 +187,18 @@ object MultiLineJsonDataSource extends JsonDataSource { .values } - private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { - CreateJacksonParser.inputStream( - jsonFactory, - CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath())) + private def dataToInputStream(dataStream: PortableDataStream): InputStream = { + val path = new Path(dataStream.getPath()) + CodecStreams.createInputStreamWithCloseResource(dataStream.getConfiguration, path) + } + + private def createParser(jsonFactory: JsonFactory, stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(jsonFactory, dataToInputStream(stream)) + } + + private def createParser(enc: String, jsonFactory: JsonFactory, + stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(enc, jsonFactory, dataToInputStream(stream)) } override def readFile( @@ -180,19 +208,22 @@ object MultiLineJsonDataSource extends JsonDataSource { schema: StructType): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { Utils.tryWithResource { - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))) } { inputStream => UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } } + val streamParser = parser.options.encoding + .map(enc => CreateJacksonParser.inputStream(enc, _: JsonFactory, _: InputStream)) + .getOrElse(CreateJacksonParser.inputStream(_: JsonFactory, _: InputStream)) val safeParser = new FailureSafeParser[InputStream]( - input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + input => parser.parse[InputStream](input, streamParser, partitionedFileString), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) safeParser.parse( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)) + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath)))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 0862c746fffa..3b04510d2969 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.nio.charset.{Charset, StandardCharsets} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -151,7 +153,13 @@ private[json] class JsonOutputWriter( context: TaskAttemptContext) extends OutputWriter with Logging { - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val encoding = options.encoding match { + case Some(charsetName) => Charset.forName(charsetName) + case None => StandardCharsets.UTF_8 + } + + private val writer = CodecStreams.createOutputStreamWriter( + context, new Path(path), encoding) // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index f7471cd7debc..1de2ca2914c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -118,6 +119,13 @@ class OrcFileFormat } } + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { + val conf = sparkSession.sessionState.conf + conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && + schema.forall(_.dataType.isInstanceOf[AtomicType]) + } + override def isSplitable( sparkSession: SparkSession, options: Map[String, String], @@ -139,6 +147,13 @@ class OrcFileFormat } } + val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) + val sqlConf = sparkSession.sessionState.conf + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled + val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val capacity = sqlConf.orcVectorizedReaderBatchSize + val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK) + val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis @@ -146,8 +161,14 @@ class OrcFileFormat (file: PartitionedFile) => { val conf = broadcastedConf.value.value + val filePath = new Path(new URI(file.filePath)) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(filePath, readerOptions) + val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds( - 
isCaseSensitive, dataSchema, requiredSchema, new Path(new URI(file.filePath)), conf) + isCaseSensitive, dataSchema, requiredSchema, reader, conf) if (requestedColIdsOrEmptyFile.isEmpty) { Iterator.empty @@ -155,29 +176,50 @@ class OrcFileFormat val requestedColIds = requestedColIdsOrEmptyFile.get assert(requestedColIds.length == requiredSchema.length, "[BUG] requested column IDs do not match required schema") - conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, + val taskConf = new Configuration(conf) + taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, requestedColIds.filter(_ != -1).sorted.mkString(",")) - val fileSplit = - new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) - val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - - val orcRecordReader = new OrcInputFormat[OrcStruct] - .createRecordReader(fileSplit, taskAttemptContext) - val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) - - val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes - val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) - val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) - - if (partitionSchema.length == 0) { - iter.map(value => unsafeProjection(deserializer.deserialize(value))) + val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) + + val taskContext = Option(TaskContext.get()) + if (enableVectorizedReader) { + val batchReader = new OrcColumnarBatchReader( + enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity) + // SPARK-23399 Register a task completion listener first to call `close()` in all cases. + // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) + // after opening a file. 
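The vectorized ORC branch above is gated by session configuration and by supportBatch (a flat schema of atomic types with a small enough column count). A minimal usage sketch, assuming the usual configuration keys for the native ORC reader; the path is illustrative:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("orc-batch-read").getOrCreate()

    // Assumed keys: select the native ORC implementation and its vectorized reader.
    spark.conf.set("spark.sql.orc.impl", "native")
    spark.conf.set("spark.sql.orc.enableVectorizedReader", "true")

    // The columnar path is used only when every column is an atomic type and the total
    // column count stays below spark.sql.codegen.maxFields, mirroring supportBatch above.
    val df = spark.read.orc("/tmp/example.orc") // illustrative path
    df.filter(df("id") > 0).show()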
+ val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + + batchReader.initialize(fileSplit, taskAttemptContext) + batchReader.initBatch( + reader.getSchema, + requestedColIds, + requiredSchema.fields, + partitionSchema, + file.partitionValues) + + iter.asInstanceOf[Iterator[InternalRow]] } else { - val joinedRow = new JoinedRow() - iter.map(value => - unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + val orcRecordReader = new OrcInputFormat[OrcStruct] + .createRecordReader(fileSplit, taskAttemptContext) + val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) + + if (partitionSchema.length == 0) { + iter.map(value => unsafeProjection(deserializer.deserialize(value))) + } else { + val joinedRow = new JoinedRow() + iter.map(value => + unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index c866dd834a52..0ad3862f6cf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -67,4 +67,6 @@ object OrcOptions { "snappy" -> "SNAPPY", "zlib" -> "ZLIB", "lzo" -> "LZO") + + def getORCCompressionCodecName(name: String): String = shortOrcCompressionCodecNames(name) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index b03ee06d04a1..460194ba61c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -21,8 +21,9 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcFile, TypeDescription} +import org.apache.orc.{OrcFile, Reader, TypeDescription} +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession @@ -50,23 +51,35 @@ object OrcUtils extends Logging { paths } - def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { + def readSchema(file: Path, conf: Configuration, ignoreCorruptFiles: Boolean) + : Option[TypeDescription] = { val fs = file.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) - val schema = reader.getSchema - if (schema.getFieldNames.size == 0) { - None - } else { - Some(schema) + try { + val reader = OrcFile.createReader(file, readerOptions) + val schema = reader.getSchema + if (schema.getFieldNames.size == 0) { + None + } else { + Some(schema) + } + } catch { + case e: org.apache.orc.FileFormatException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in 
the corrupted file: $file", e) + None + } else { + throw new SparkException(s"Could not read footer for file: $file", e) + } } } def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) : Option[StructType] = { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles val conf = sparkSession.sessionState.newHadoopConf() // TODO: We need to support merge schema. Please see SPARK-11412. - files.map(_.getPath).flatMap(readSchema(_, conf)).headOption.map { schema => + files.map(_.getPath).flatMap(readSchema(_, conf, ignoreCorruptFiles)).headOption.map { schema => logDebug(s"Reading schema from file $files, got Hive schema string: $schema") CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] } @@ -80,11 +93,8 @@ object OrcUtils extends Logging { isCaseSensitive: Boolean, dataSchema: StructType, requiredSchema: StructType, - file: Path, + reader: Reader, conf: Configuration): Option[Array[Int]] = { - val fs = file.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) val orcFieldNames = reader.getSchema.getFieldNames.asScala if (orcFieldNames.isEmpty) { // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 45bedf70f975..d1f9e11ed422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -108,8 +108,7 @@ class ParquetFileFormat ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) - // We want to clear this temporary metadata from saving into Parquet file. - // This metadata is only useful for detecting optional columns when pushdowning filters. + // This metadata is useful for keeping UDTs like Vector/Matrix. ParquetWriteSupport.setSchema(dataSchema, conf) // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to Parquet @@ -322,19 +321,6 @@ class ParquetFileFormat SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sparkSession.sessionState.conf.isParquetINT96AsTimestamp) - // Try to push down filters when filter push-down is enabled. - val pushed = - if (sparkSession.sessionState.conf.parquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. 
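The corrupt-footer handling above reuses the session-wide flag for ignoring corrupt files. A hedged usage sketch; the path is illustrative and the key is assumed to be the standard one backing ignoreCorruptFiles:

    // Assuming an active SparkSession `spark`.
    // With the flag on, a file whose ORC footer cannot be read is logged and skipped during
    // schema inference; with it off, the read fails with the SparkException thrown above.
    spark.conf.set("spark.sql.files.ignoreCorruptFiles", "true")
    val df = spark.read.orc("/data/orc-with-possibly-corrupt-files") // illustrative path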
- .flatMap(ParquetFilters.createFilter(requiredSchema, _)) - .reduceOption(FilterApi.and) - } else { - None - } - val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -351,12 +337,28 @@ class ParquetFileFormat sparkSession.sessionState.conf.parquetRecordFilterEnabled val timestampConversion: Boolean = sparkSession.sessionState.conf.isParquetINT96TimestampConversion + val capacity = sqlConf.parquetVectorizedReaderBatchSize + val enableParquetFilterPushDown: Boolean = + sparkSession.sessionState.conf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) + val pushDownDate = sqlConf.parquetFilterPushDownDate (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(new ParquetFilters(pushDownDate).createFilter(requiredSchema, _)) + .reduceOption(FilterApi.and) + } else { + None + } + val fileSplit = new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) @@ -395,16 +397,21 @@ class ParquetFileFormat ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) } val taskContext = Option(TaskContext.get()) - val parquetReader = if (enableVectorizedReader) { + if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader( - convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined) + convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) + val iter = new RecordReaderIterator(vectorizedReader) + // SPARK-23457 Register a task completion lister before `initialization`. + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) if (returningBatch) { vectorizedReader.enableReturningBatches() } - vectorizedReader + + // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. + iter.asInstanceOf[Iterator[InternalRow]] } else { logDebug(s"Falling back to parquet-mr") // ParquetRecordReader returns UnsafeRow @@ -414,18 +421,11 @@ class ParquetFileFormat } else { new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz)) } + val iter = new RecordReaderIterator(reader) + // SPARK-23457 Register a task completion lister before `initialization`. + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) reader.initialize(split, hadoopAttemptContext) - reader - } - val iter = new RecordReaderIterator(parquetReader) - taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) - - // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. 
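Filter construction now happens inside the per-file closure through a ParquetFilters instance, and DATE columns become eligible for pushdown when the (assumed) flag spark.sql.parquet.filterPushdown.date is enabled. A small sketch of the day-number conversion that makes a java.sql.Date comparable to Parquet's INT32 date encoding; the column name is illustrative:

    import java.sql.Date
    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    // Same conversion as dateToDays below: a java.sql.Date becomes days since the epoch,
    // which is how Parquet encodes DATE values in an INT32 column.
    val days: Int = DateTimeUtils.fromJavaDate(Date.valueOf("2018-03-01"))

    // For a predicate like d = DATE '2018-03-01', the pushed filter is then roughly
    // FilterApi.eq(intColumn("d"), days: java.lang.Integer).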
- if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && - enableVectorizedReader) { - iter.asInstanceOf[Iterator[InternalRow]] - } else { val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val joinedRow = new JoinedRow() val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 763841efbd9f..310626197a76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,17 +17,25 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.sql.Date + import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate import org.apache.spark.sql.sources import org.apache.spark.sql.types._ /** * Some utility function to convert Spark data source filters to Parquet filters. */ -private[parquet] object ParquetFilters { +private[parquet] class ParquetFilters(pushDownDate: Boolean) { + + private def dateToDays(date: Date): SQLDate = { + DateTimeUtils.fromJavaDate(date) + } private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case BooleanType => @@ -50,6 +58,10 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case DateType if pushDownDate => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -72,6 +84,10 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case DateType if pushDownDate => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -91,6 +107,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if pushDownDate => + (n: String, v: Any) => FilterApi.lt( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -110,6 +130,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if pushDownDate => + (n: String, v: Any) => FilterApi.ltEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -129,6 +153,10 @@ private[parquet] object ParquetFilters { case 
BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if pushDownDate => + (n: String, v: Any) => FilterApi.gt( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -148,6 +176,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if pushDownDate => + (n: String, v: Any) => FilterApi.gtEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 772d4565de54..f36a89a4c3c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Locale +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -27,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Options for the Parquet data source. */ -private[parquet] class ParquetOptions( +class ParquetOptions( @transient private val parameters: CaseInsensitiveMap[String], @transient private val sqlConf: SQLConf) extends Serializable { @@ -42,8 +43,15 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", - sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) + // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and + // `spark.sql.parquet.compression.codec` + // are in order of precedence from highest to lowest. 
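A hedged usage sketch of the precedence spelled out in the comment above: the per-write option "compression" wins over "parquet.compression", which in turn wins over the session default. An in-scope `spark` and `df` and the output path are assumptions:

    // Session-level default (lowest precedence).
    spark.conf.set("spark.sql.parquet.compression.codec", "gzip")

    df.write
      .option("parquet.compression", "lzo") // overrides the session default
      .option("compression", "snappy")      // highest precedence, takes effect
      .parquet("/tmp/parquet-out")          // illustrative output path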
+ val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) + val codecName = parameters + .get("compression") + .orElse(parquetCompressionConf) + .getOrElse(sqlConf.parquetCompressionCodec) + .toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) @@ -74,4 +82,8 @@ object ParquetOptions { "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, "lzo" -> CompressionCodecName.LZO) + + def getParquetCompressionCodecName(name: String): String = { + shortParquetCompressionCodecNames(name).name() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index f64e079539c4..0dea767840ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{AtomicType, StructType} import org.apache.spark.sql.util.SchemaUtils /** - * Try to replaces [[UnresolvedRelation]]s if the plan is for direct query on files. + * Replaces [[UnresolvedRelation]]s if the plan is for direct query on files. */ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { private def maybeSQLFile(u: UnresolvedRelation): Boolean = { @@ -118,6 +118,14 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi s"`${existingProvider.getSimpleName}`. It doesn't match the specified format " + s"`${specifiedProvider.getSimpleName}`.") } + tableDesc.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + throw new AnalysisException( + s"The location of the existing table ${tableIdentWithDB.quotedString} is " + + s"`${existingTable.location}`. It doesn't match the specified location " + + s"`${tableDesc.location}`.") + case _ => + } if (query.schema.length != existingTable.schema.length) { throw new AnalysisException( @@ -178,7 +186,8 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi c.copy( tableDesc = existingTable, - query = Some(newQuery)) + query = Some(DDLPreprocessingUtils.castAndRenameQueryOutput( + newQuery, existingTable.schema.toAttributes, conf))) // Here we normalize partition, bucket and sort column names, w.r.t. the case sensitivity // config, and do various checks: @@ -316,7 +325,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. 
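The new location check above compares a user-specified path against the location of an already existing table. A hedged sketch of the failure it guards against; the table name and paths are illustrative, and the append-to-an-existing-table scenario is an assumption on my part:

    // Assume table `t` already exists and is stored under /warehouse/t.
    df.write
      .mode("append")
      .option("path", "/some/other/location") // differs from the existing table location
      .saveAsTable("t")
    // -> AnalysisException: "The location of the existing table `default`.`t` is ...
    //    It doesn't match the specified location ..."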
*/ -case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -336,6 +345,8 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit s"including ${staticPartCols.size} partition column(s) having constant value(s).") } + val newQuery = DDLPreprocessingUtils.castAndRenameQueryOutput( + insert.query, expectedColumns, conf) if (normalizedPartSpec.nonEmpty) { if (normalizedPartSpec.size != partColNames.length) { throw new AnalysisException( @@ -346,37 +357,11 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit """.stripMargin) } - castAndRenameChildOutput(insert.copy(partition = normalizedPartSpec), expectedColumns) + insert.copy(query = newQuery, partition = normalizedPartSpec) } else { // All partition columns are dynamic because the InsertIntoTable command does // not explicitly specify partitioning columns. - castAndRenameChildOutput(insert, expectedColumns) - .copy(partition = partColNames.map(_ -> None).toMap) - } - } - - private def castAndRenameChildOutput( - insert: InsertIntoTable, - expectedOutput: Seq[Attribute]): InsertIntoTable = { - val newChildOutput = expectedOutput.zip(insert.query.output).map { - case (expected, actual) => - if (expected.dataType.sameType(actual.dataType) && - expected.name == actual.name && - expected.metadata == actual.metadata) { - actual - } else { - // Renaming is needed for handling the following cases like - // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 - // 2) Target tables have column metadata - Alias(cast(actual, expected.dataType), expected.name)( - explicitMetadata = Option(expected.metadata)) - } - } - - if (newChildOutput == insert.query.output) { - insert - } else { - insert.copy(query = Project(newChildOutput, insert.query)) + insert.copy(query = newQuery, partition = partColNames.map(_ -> None).toMap) } } @@ -491,3 +476,36 @@ object PreWriteCheck extends (LogicalPlan => Unit) { } } } + +object DDLPreprocessingUtils { + + /** + * Adjusts the name and data type of the input query output columns, to match the expectation. 
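With the casting logic factored into the shared helper described above, INSERT INTO and table creation against an existing table get the same output adjustment. A minimal sketch of the effect, matching the comment's own example of INSERT INTO TABLE tab1 SELECT 1, 2; the table name is illustrative:

    // Assuming an active SparkSession `spark`.
    spark.sql("CREATE TABLE target (a INT, b STRING) USING parquet")

    // The query output (1, 2) does not match the table schema exactly: the second column is
    // renamed to `b` and cast from INT to STRING, using the session-local time zone for any
    // time-zone-aware casts, as in castAndRenameQueryOutput below.
    spark.sql("INSERT INTO target SELECT 1, 2")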
+ */ + def castAndRenameQueryOutput( + query: LogicalPlan, + expectedOutput: Seq[Attribute], + conf: SQLConf): LogicalPlan = { + val newChildOutput = expectedOutput.zip(query.output).map { + case (expected, actual) => + if (expected.dataType.sameType(actual.dataType) && + expected.name == actual.name && + expected.metadata == actual.metadata) { + actual + } else { + // Renaming is needed for handling the following cases like + // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 + // 2) Target tables have column metadata + Alias( + Cast(actual, expected.dataType, Option(conf.sessionLocalTimeZone)), + expected.name)(explicitMetadata = Option(expected.metadata)) + } + } + + if (newChildOutput == query.output) { + query + } else { + Project(newChildOutput, query) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index c661e9bd3b94..e93908da4353 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.execution.datasources.text -import java.io.Closeable - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext @@ -29,7 +26,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -89,7 +86,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new TextOutputWriter(path, dataSchema, context) + new TextOutputWriter(path, dataSchema, textOptions.lineSeparatorInWrite, context) } override def getFileExtension(context: TaskAttemptContext): String = { @@ -113,18 +110,18 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions.wholeText) + readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions) } private def readToUnsafeMem( conf: Broadcast[SerializableConfiguration], requiredSchema: StructType, - wholeTextMode: Boolean): (PartitionedFile) => Iterator[UnsafeRow] = { + textOptions: TextOptions): (PartitionedFile) => Iterator[UnsafeRow] = { (file: PartitionedFile) => { val confValue = conf.value.value - val reader = if (!wholeTextMode) { - new HadoopFileLinesReader(file, confValue) + val reader = if (!textOptions.wholeText) { + new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue) } else { new HadoopFileWholeTextReader(file, confValue) } @@ -133,16 +130,13 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val emptyUnsafeRow = new 
UnsafeRow(0) reader.map(_ => emptyUnsafeRow) } else { - val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + val unsafeRowWriter = new UnsafeRowWriter(1) reader.map { line => // Writes to an UnsafeRow directly - bufferHolder.reset() + unsafeRowWriter.reset() unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) - unsafeRow + unsafeRowWriter.getRow() } } } @@ -152,6 +146,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { class TextOutputWriter( path: String, dataSchema: StructType, + lineSeparator: Array[Byte], context: TaskAttemptContext) extends OutputWriter { @@ -162,7 +157,7 @@ class TextOutputWriter( val utf8string = row.getUTF8String(0) utf8string.writeTo(writer) } - writer.write('\n') + writer.write(lineSeparator) } override def close(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 2a661561ab51..e4e201995faa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.text +import java.nio.charset.{Charset, StandardCharsets} + import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} /** @@ -39,9 +41,25 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti */ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean + val encoding: Option[String] = parameters.get(ENCODING) + + val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { lineSep => + require(lineSep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + + lineSep + } + + // Note that the option 'lineSep' uses a different default value in read and write. + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.map(Charset.forName(_)).getOrElse(StandardCharsets.UTF_8)) + } + val lineSeparatorInWrite: Array[Byte] = + lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } -private[text] object TextOptions { +private[datasources] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" + val ENCODING = "encoding" + val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala new file mode 100644 index 000000000000..017a6737161a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Partitioning} + +/** + * An adapter from public data source partitioning to catalyst internal `Partitioning`. + */ +class DataSourcePartitioning( + partitioning: Partitioning, + colNames: AttributeMap[String]) extends physical.Partitioning { + + override val numPartitions: Int = partitioning.numPartitions() + + override def satisfies(required: physical.Distribution): Boolean = { + super.satisfies(required) || { + required match { + case d: physical.ClusteredDistribution if isCandidate(d.clustering) => + val attrs = d.clustering.map(_.asInstanceOf[Attribute]) + partitioning.satisfy( + new ClusteredDistribution(attrs.map { a => + val name = colNames.get(a) + assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output") + name.get + }.toArray)) + + case _ => false + } + } + } + + private def isCandidate(clustering: Seq[Expression]): Boolean = { + clustering.forall { + case a: Attribute => colNames.contains(a) + case _ => false + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 5f30be5ed4af..f85971be394b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -18,30 +18,30 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -class DataSourceRDDPartition(val index: Int, val readTask: ReadTask[UnsafeRow]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T]) extends Partition with Serializable -class DataSourceRDD( +class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { + @transient private val readerFactories: Seq[DataReaderFactory[T]]) + extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readTasks.asScala.zipWithIndex.map { - case (readTask, index) => new DataSourceRDDPartition(index, readTask) + readerFactories.zipWithIndex.map { + case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + override def 
compute(split: Partition, context: TaskContext): Iterator[T] = { + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader() context.addTaskCompletionListener(_ => reader.close()) - val iter = new Iterator[UnsafeRow] { + val iter = new Iterator[T] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -51,7 +51,7 @@ class DataSourceRDD( valuePrepared } - override def next(): UnsafeRow = { + override def next(): T = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -63,6 +63,6 @@ class DataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala deleted file mode 100644 index 6093df26630c..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.sources.v2.reader._ - -/** - * A base class for data source reader holder with customized equals/hashCode methods. - */ -trait DataSourceReaderHolder { - - /** - * The full output of the data source reader, without column pruning. - */ - def fullOutput: Seq[AttributeReference] - - /** - * The held data source reader. - */ - def reader: DataSourceV2Reader - - /** - * The metadata of this data source reader that can be used for equality test. 
- */ - private def metadata: Seq[Any] = { - val filters: Any = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Nil - } - Seq(fullOutput, reader.getClass, reader.readSchema(), filters) - } - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceReaderHolder => - canEqual(other) && metadata.length == other.metadata.length && - metadata.zip(other.metadata).forall { case (l, r) => l == r } - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } - - lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => - fullOutput.find(_.name == name).get - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 7eb99a645001..90fb5a14c9fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,15 +17,129 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.reader._ +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.types.StructType case class DataSourceV2Relation( - fullOutput: Seq[AttributeReference], - reader: DataSourceV2Reader) extends LeafNode with DataSourceReaderHolder { + source: DataSourceV2, + options: Map[String, String], + projection: Seq[AttributeReference], + filters: Option[Seq[Expression]] = None, + userSpecifiedSchema: Option[StructType] = None) + extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { + + import DataSourceV2Relation._ + + override def simpleString: String = "RelationV2 " + metadataString + + override lazy val schema: StructType = reader.readSchema() + + override lazy val output: Seq[AttributeReference] = { + // use the projection attributes to avoid assigning new ids. fields that are not projected + // will be assigned new ids, which is okay because they are not projected. 
+ val attrMap = projection.map(a => a.name -> a).toMap + schema.map(f => attrMap.getOrElse(f.name, + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) + } + + private lazy val v2Options: DataSourceOptions = makeV2Options(options) - override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] + // postScanFilters: filters that need to be evaluated after the scan. + // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources. + // Note: postScanFilters and pushedFilters can overlap, e.g. the parquet row group filter. + lazy val ( + reader: DataSourceReader, + postScanFilters: Seq[Expression], + pushedFilters: Seq[Expression]) = { + val newReader = userSpecifiedSchema match { + case Some(s) => + source.asReadSupportWithSchema.createReader(s, v2Options) + case _ => + source.asReadSupport.createReader(v2Options) + } + + DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) + + val (postScanFilters, pushedFilters) = filters match { + case Some(filterSeq) => + DataSourceV2Relation.pushFilters(newReader, filterSeq) + case _ => + (Nil, Nil) + } + logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") + logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") + + (newReader, postScanFilters, pushedFilters) + } + + override def doCanonicalize(): LogicalPlan = { + val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] + + // override output with canonicalized output to avoid attempting to configure a reader + val canonicalOutput: Seq[AttributeReference] = this.output + .map(a => QueryPlan.normalizeExprId(a, projection)) + + new DataSourceV2Relation(c.source, c.options, c.projection) { + override lazy val output: Seq[AttributeReference] = canonicalOutput + } + } + + override def computeStats(): Statistics = reader match { + case r: SupportsReportStatistics => + Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + case _ => + Statistics(sizeInBytes = conf.defaultSizeInBytes) + } + + override def newInstance(): DataSourceV2Relation = { + // projection is used to maintain id assignment. + // if projection is not set, use output so the copy is not equal to the original + copy(projection = projection.map(_.newInstance())) + } +} + +/** + * A specialization of [[DataSourceV2Relation]] with the streaming bit set to true. + * + * Note that, this plan has a mutable reader, so Spark won't apply operator push-down for this plan, + * to avoid making the plan mutable. We should consolidate this plan and [[DataSourceV2Relation]] + * after we figure out how to apply operator push-down for streaming data sources. + */ +case class StreamingDataSourceV2Relation( + output: Seq[AttributeReference], + source: DataSourceV2, + options: Map[String, String], + reader: DataSourceReader) + extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { + + override def isStreaming: Boolean = true + + override def simpleString: String = "Streaming RelationV2 " + metadataString + + override def pushedFilters: Seq[Expression] = Nil + + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) + + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
+ override def equals(other: Any): Boolean = other match { + case other: StreamingDataSourceV2Relation => + output == other.output && reader.getClass == other.reader.getClass && options == other.options + case _ => false + } + + override def hashCode(): Int = { + Seq(output, source, options).hashCode() + } override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => @@ -36,7 +150,111 @@ case class DataSourceV2Relation( } object DataSourceV2Relation { - def apply(reader: DataSourceV2Reader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + private implicit class SourceHelpers(source: DataSourceV2) { + def asReadSupport: ReadSupport = { + source match { + case support: ReadSupport => + support + case _: ReadSupportWithSchema => + // this method is only called if there is no user-supplied schema. if there is no + // user-supplied schema and ReadSupport was not implemented, throw a helpful exception. + throw new AnalysisException(s"Data source requires a user-supplied schema: $name") + case _ => + throw new AnalysisException(s"Data source is not readable: $name") + } + } + + def asReadSupportWithSchema: ReadSupportWithSchema = { + source match { + case support: ReadSupportWithSchema => + support + case _: ReadSupport => + throw new AnalysisException( + s"Data source does not support user-supplied schema: $name") + case _ => + throw new AnalysisException(s"Data source is not readable: $name") + } + } + + def name: String = { + source match { + case registered: DataSourceRegister => + registered.shortName() + case _ => + source.getClass.getSimpleName + } + } + } + + private def makeV2Options(options: Map[String, String]): DataSourceOptions = { + new DataSourceOptions(options.asJava) + } + + private def schema( + source: DataSourceV2, + v2Options: DataSourceOptions, + userSchema: Option[StructType]): StructType = { + val reader = userSchema match { + case Some(s) => + source.asReadSupportWithSchema.createReader(s, v2Options) + case _ => + source.asReadSupport.createReader(v2Options) + } + reader.readSchema() + } + + def create( + source: DataSourceV2, + options: Map[String, String], + filters: Option[Seq[Expression]] = None, + userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { + val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes + DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) + } + + private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { + reader match { + case projectionSupport: SupportsPushDownRequiredColumns => + projectionSupport.pruneColumns(struct) + case _ => + } + } + + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case r: SupportsPushDownCatalystFilters => + val postScanFilters = r.pushCatalystFilters(filters.toArray) + val pushedFilters = r.pushedCatalystFilters() + (postScanFilters, pushedFilters) + + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = scala.collection.mutable.HashMap.empty[Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. 
+ val untranslatableExprs = scala.collection.mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } + } + + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = + r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + + (untranslatableExprs ++ postScanFilters, pushedFilters) + + case _ => (filters, Nil) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 49c506bc560c..77cb707340b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -24,65 +24,122 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.LeafExecNode -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions} +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch /** * Physical plan node for scanning data from a data source. */ case class DataSourceV2ScanExec( - fullOutput: Seq[AttributeReference], - @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder { + output: Seq[AttributeReference], + @transient source: DataSourceV2, + @transient options: Map[String, String], + @transient pushedFilters: Seq[Expression], + @transient reader: DataSourceReader) + extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { + + override def simpleString: String = "ScanV2 " + metadataString + + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
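The split computed in pushFilters above separates predicates the source can accept from those Spark must still evaluate, and the two sets may overlap (the comment's example being a Parquet row-group filter, which skips data by statistics but cannot guarantee row-level correctness). A hedged sketch of the public filter shapes involved, for a predicate such as a > 1 AND myUdf(b); the names are illustrative:

    import org.apache.spark.sql.sources.{Filter, GreaterThan}

    // Translatable part: handed to the source via SupportsPushDownFilters.pushFilters.
    val translated: Seq[Filter] = Seq(GreaterThan("a", 1))

    // A predicate like myUdf(b) has no sources.Filter representation, so it lands in
    // untranslatableExprs above and is planned as a post-scan Filter by Spark.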
+ override def equals(other: Any): Boolean = other match { + case other: DataSourceV2ScanExec => + output == other.output && reader.getClass == other.reader.getClass && options == other.options + case _ => false + } - override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] + override def hashCode(): Int = { + Seq(output, source, options).hashCode() + } - override def references: AttributeSet = AttributeSet.empty + override def outputPartitioning: physical.Partitioning = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 => + SinglePartition - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 => + SinglePartition - override protected def doExecute(): RDD[InternalRow] = { - val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() - case _ => - reader.createReadTasks().asScala.map { - new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] - }.asJava - } + case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 => + SinglePartition - val inputRDD = reader match { - case _: ContinuousReader => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readTasks.size())) + case s: SupportsReportPartitioning => + new DataSourcePartitioning( + s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + case _ => super.outputPartitioning + } - case _ => - new DataSourceRDD(sparkContext, readTasks) - } + private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala + case _ => + reader.createDataReaderFactories().asScala.map { + new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] + } + } + + private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + assert(!reader.isInstanceOf[ContinuousReader], + "continuous stream reader does not support columnar read yet.") + r.createBatchDataReaderFactories().asScala + } + + private lazy val inputRDD: RDD[InternalRow] = reader match { + case _: ContinuousReader => + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) + .askSync[Unit](SetReaderPartitions(readerFactories.size)) + new ContinuousDataSourceRDD( + sparkContext, + sqlContext.conf.continuousStreamingExecutorQueueSize, + sqlContext.conf.continuousStreamingExecutorPollIntervalMs, + readerFactories) + .asInstanceOf[RDD[InternalRow]] + + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]] + + case _ => + new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) - val numOutputRows = longMetric("numOutputRows") - inputRDD.asInstanceOf[RDD[InternalRow]].map { r => - numOutputRows += 1 - r + override val supportsBatch: Boolean = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => 
true + case _ => false + } + + override protected def needsUnsafeRowConversion: Boolean = false + + override protected def doExecute(): RDD[InternalRow] = { + if (supportsBatch) { + WholeStageCodegenExec(this)(codegenStageId = 0).execute() + } else { + val numOutputRows = longMetric("numOutputRows") + inputRDD.map { r => + numOutputRows += 1 + r + } } } } -class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) - extends ReadTask[UnsafeRow] { +class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType) + extends DataReaderFactory[UnsafeRow] { - override def preferredLocations: Array[String] = rowReadTask.preferredLocations + override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations override def createDataReader: DataReader[UnsafeRow] = { new RowToUnsafeDataReader( - rowReadTask.createDataReader, RowEncoder.apply(schema).resolveAndBind()) + rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index df5b524485f5..1b7c639f10f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -20,15 +20,22 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case DataSourceV2Relation(output, reader) => - DataSourceV2ScanExec(output, reader) :: Nil + case r: DataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil + + case r: StreamingDataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case WriteToContinuousDataSource(writer, query) => + WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala new file mode 100644 index 000000000000..693e67dcd108 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.util.Utils + +/** + * A trait that can be used by data source v2 related query plans(both logical and physical), to + * provide a string format of the data source information for explain. + */ +trait DataSourceV2StringFormat { + + /** + * The instance of this data source implementation. Note that we only consider its class in + * equals/hashCode, not the instance itself. + */ + def source: DataSourceV2 + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The options for this data source reader. + */ + def options: Map[String, String] + + /** + * The filters which have been pushed to the data source. + */ + def pushedFilters: Seq[Expression] + + private def sourceName: String = source match { + case registered: DataSourceRegister => registered.shortName() + case _ => source.getClass.getSimpleName.stripSuffix("$") + } + + def metadataString: String = { + val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] + + if (pushedFilters.nonEmpty) { + entries += "Filters" -> pushedFilters.mkString("[", ", ", "]") + } + + // TODO: we should only display some standard options like path, table, etc. 
+ if (options.nonEmpty) { + entries += "Options" -> Utils.redact(options).map { + case (k, v) => s"$k=$v" + }.mkString("[", ",", "]") + } + + val outputStr = Utils.truncatedString(output, "[", ", ", "]") + + val entriesStr = if (entries.nonEmpty) { + Utils.truncatedString(entries.map { + case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) + }, " (", ", ", ")") + } else { + "" + } + + s"$sourceName$outputStr$entriesStr" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index df034adf1e7d..9293d4f831bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -17,120 +17,55 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper} -import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources -import org.apache.spark.sql.sources.v2.reader._ -/** - * Pushes down various operators to the underlying data source for better performance. Operators are - * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you - * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the - * data source should execute FILTER before LIMIT. And required columns are calculated at the end, - * because when more operators are pushed down, we may need less columns at Spark side. - */ -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper { - override def apply(plan: LogicalPlan): LogicalPlan = { - // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may - // appear in many places for column pruning. - // TODO: Ideally column pruning should be implemented via a plan property that is propagated - // top-down, then we can simplify the logic here and only collect target operators. - val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => - val (candidates, nonDeterministic) = - splitConjunctivePredicates(condition).partition(_.deterministic) - - val stayUpFilters: Seq[Expression] = reader match { - case r: SupportsPushDownCatalystFilters => - r.pushCatalystFilters(candidates.toArray) - - case r: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. 
- val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet - val unhandledPredicates = translatedMap.filter { case (_, f) => - unhandledFilters.contains(f) - }.keys - - nonConvertiblePredicates ++ unhandledPredicates - - case _ => candidates - } - - val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And) - val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) - if (withFilter.output == fields) { - withFilter - } else { - Project(fields, withFilter) - } - } - - // TODO: add more push down rules. - - // TODO: nested fields pruning - def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = { - plan match { - case Project(projectList, child) => - val required = projectList.filter(requiredByParent.contains).flatMap(_.references) - pushDownRequiredColumns(child, required) - - case Filter(condition, child) => - val required = requiredByParent ++ condition.references - pushDownRequiredColumns(child, required) - - case DataSourceV2Relation(fullOutput, reader) => reader match { - case r: SupportsPushDownRequiredColumns => - // Match original case of attributes. - val attrMap = AttributeMap(fullOutput.zip(fullOutput)) - val requiredColumns = requiredByParent.map(attrMap) - r.pruneColumns(requiredColumns.toStructType) - case _ => - } - - // TODO: there may be more operators can be used to calculate required columns, we can add - // more and more in the future. - case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output)) +object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { + override def apply( + plan: LogicalPlan): LogicalPlan = plan transformUp { + // PhysicalOperation guarantees that filters are deterministic; no need to check + case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => + // merge the filters + val filters = relation.filters match { + case Some(existing) => + existing ++ newFilters + case _ => + newFilters } - } - pushDownRequiredColumns(filterPushed, filterPushed.output) - // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(filterPushed) - } - - /** - * Finds a Filter node(with an optional Project child) above data source relation. - */ - object FilterAndProject { - // returns the project list, the filter condition and the data source relation. - def unapply(plan: LogicalPlan) - : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match { + val projectAttrs = project.map(_.toAttribute) + val projectSet = AttributeSet(project.flatMap(_.references)) + val filterSet = AttributeSet(filters.flatMap(_.references)) + + val projection = if (filterSet.subsetOf(projectSet) && + AttributeSet(projectAttrs) == projectSet) { + // When the required projection contains all of the filter columns and column pruning alone + // can produce the required projection, push the required projection. + // A final projection may still be needed if the data source produces a different column + // order or if it cannot prune all of the nested columns. 
+ projectAttrs + } else { + // When there are filter columns not already in the required projection or when the required + // projection is more complicated than column pruning, base column pruning on the set of + // all columns needed by both. + (projectSet ++ filterSet).toSeq + } - case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r)) + val newRelation = relation.copy( + projection = projection.asInstanceOf[Seq[AttributeReference]], + filters = Some(filters)) - case Filter(condition, Project(fields, r: DataSourceV2Relation)) - if fields.forall(_.deterministic) => - val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e)) - val substituted = condition.transform { - case a: Attribute => attributeMap.getOrElse(a, a) - } - Some((fields, substituted, r)) + // Add a Filter for any filters that need to be evaluated after scan. + val postScanFilterCond = newRelation.postScanFilters.reduceLeftOption(And) + val filtered = postScanFilterCond.map(Filter(_, newRelation)).getOrElse(newRelation) - case _ => None - } + // Add a Project to ensure the output matches the required projection + if (newRelation.output != projectAttrs) { + Project(project, filtered) + } else { + filtered + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index f0bdf84bb7a8..ea283ed77efd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.util.control.NonFatal + import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -26,17 +29,17 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution} import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils /** * The logical plan for writing data into data source v2. */ -case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) extends LogicalPlan { +case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } @@ -44,7 +47,7 @@ case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) e /** * The physical plan for writing data into data source v2. 
*/ -case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) extends SparkPlan { +case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) extends SparkPlan { override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil @@ -54,6 +57,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) } + val useCommitCoordinator = writer.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) @@ -61,32 +65,21 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) s"The input RDD has ${messages.length} partitions.") try { - val runTask = writer match { - case w: ContinuousWriter => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) - .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) - - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.runContinuous(writeTask, context, iter) - case _ => - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter) - } - sparkContext.runJob( rdd, - runTask, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator), rdd.partitions.indices, - (index, message: WriterCommitMessage) => messages(index) = message + (index, message: WriterCommitMessage) => { + messages(index) = message + writer.onDataWriterCommit(message) + } ) logInfo(s"Data source writer $writer is committing.") writer.commit(messages) logInfo(s"Data source writer $writer committed.") } catch { - case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => - // Interruption is how continuous queries are ended, so accept and ignore the exception. case cause: Throwable => logError(s"Data source writer $writer is aborting.") try { @@ -98,7 +91,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) throw new SparkException("Writing job failed.", cause) } logError(s"Data source writer $writer aborted.") - throw new SparkException("Writing job aborted.", cause) + cause match { + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } } sparkContext.emptyRDD @@ -109,70 +106,59 @@ object DataWritingSparkTask extends Logging { def run( writeTask: DataWriterFactory[InternalRow], context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) + iter: Iterator[InternalRow], + useCommitCoordinator: Boolean): WriterCommitMessage = { + val stageId = context.stageId() + val partId = context.partitionId() + val attemptId = context.attemptNumber() + val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") + val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong) // write the data and commit this writer. 
Utils.tryWithSafeFinallyAndFailureCallbacks(block = { iter.foreach(dataWriter.write) - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") + + val msg = if (useCommitCoordinator) { + val coordinator = SparkEnv.get.outputCommitCoordinator + val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId) + if (commitAuthorized) { + logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.") + dataWriter.commit() + } else { + val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit" + logInfo(message) + // throwing CommitDeniedException will trigger the catch block for abort + throw new CommitDeniedException(message, stageId, partId, attemptId) + } + + } else { + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + dataWriter.commit() + } + + logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.") + msg + })(catchBlock = { // If there is an error, abort this writer - logError(s"Writer for partition ${context.partitionId()} is aborting.") + logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.") dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") + logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") }) } - - def runContinuous( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) - val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), - SparkEnv.get) - val currentMsg: WriterCommitMessage = null - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - do { - // write the data and commit this writer. - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - try { - iter.foreach(dataWriter.write) - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") - epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) - ) - currentEpoch += 1 - } catch { - case _: InterruptedException => - // Continuous shutdown always involves an interrupt. Just finish the task. 
- } - })(catchBlock = { - // If there is an error, abort this writer - logError(s"Writer for partition ${context.partitionId()} is aborting.") - dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") - }) - } while (!context.isInterrupted()) - - currentMsg - } } class InternalRowDataWriterFactory( rowWriterFactory: DataWriterFactory[Row], schema: StructType) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, attemptNumber), + rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId), RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index c8e236be28b4..e3d28388c547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -46,23 +46,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None } - /** - * Given a required distribution, returns a partitioning that satisfies that distribution. - * @param requiredDistribution The distribution that is required by the operator - * @param numPartitions Used when the distribution doesn't require a specific number of partitions - */ - private def createPartitioning( - requiredDistribution: Distribution, - numPartitions: Int): Partitioning = { - requiredDistribution match { - case AllTuples => SinglePartition - case ClusteredDistribution(clustering, desiredPartitions) => - HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions)) - case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) - case dist => sys.error(s"Do not know how to satisfy distribution $dist") - } - } - /** * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled * and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]]. @@ -88,8 +71,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // shuffle data when we have more than one children because data generated by // these children may not be partitioned in the same way. // Please see the comment in withCoordinator for more details. - val supportsDistribution = - requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) + val supportsDistribution = requiredChildDistributions.forall { dist => + dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution] + } children.length > 1 && supportsDistribution } @@ -142,8 +126,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // // It will be great to introduce a new Partitioning to represent the post-shuffle // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. 
- val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) + val targetPartitioning = distribution.createPartitioning(defaultNumPreShufflePartitions) assert(targetPartitioning.isInstanceOf[HashPartitioning]) ShuffleExchangeExec(targetPartitioning, child, Some(coordinator)) } @@ -162,71 +145,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { assert(requiredChildDistributions.length == children.length) assert(requiredChildOrderings.length == children.length) - // Ensure that the operator's children satisfy their output distribution requirements: + // Ensure that the operator's children satisfy their output distribution requirements. children = children.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => - ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + val numPartitions = distribution.requiredNumPartitions + .getOrElse(defaultNumPreShufflePartitions) + ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child) } - // If the operator has multiple children and specifies child output distributions (e.g. join), - // then the children's output partitionings must be compatible: - def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match { - case UnspecifiedDistribution => false - case BroadcastDistribution(_) => false + // Get the indexes of children which have specified distribution requirements and need to have + // same number of partitions. + val childrenIndexes = requiredChildDistributions.zipWithIndex.filter { + case (UnspecifiedDistribution, _) => false + case (_: BroadcastDistribution, _) => false case _ => true - } - if (children.length > 1 - && requiredChildDistributions.exists(requireCompatiblePartitioning) - && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { + }.map(_._2) - // First check if the existing partitions of the children all match. This means they are - // partitioned by the same partitioning into the same number of partitions. In that case, - // don't try to make them match `defaultPartitions`, just use the existing partitioning. - val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max - val useExistingPartitioning = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) + val childrenNumPartitions = + childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet + + if (childrenNumPartitions.size > 1) { + // Get the number of partitions which is explicitly required by the distributions. + val requiredNumPartitions = { + val numPartitionsSet = childrenIndexes.flatMap { + index => requiredChildDistributions(index).requiredNumPartitions + }.toSet + assert(numPartitionsSet.size <= 1, + s"$operator have incompatible requirements of the number of partitions for its children") + numPartitionsSet.headOption } - children = if (useExistingPartitioning) { - // We do not need to shuffle any child's output. - children - } else { - // We need to shuffle at least one child's output. - // Now, we will determine the number of partitions that will be used by created - // partitioning schemes. 
- val numPartitions = { - // Let's see if we need to shuffle all child's outputs when we use - // maxChildrenNumPartitions. - val shufflesAllChildren = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - !child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) - } - // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the - // number of partitions. Otherwise, we use maxChildrenNumPartitions. - if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions - } + val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max) - children.zip(requiredChildDistributions).map { - case (child, distribution) => - val targetPartitioning = createPartitioning(distribution, numPartitions) - if (child.outputPartitioning.guarantees(targetPartitioning)) { - child - } else { - child match { - // If child is an exchange, we replace it with - // a new one having targetPartitioning. - case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c) - case _ => ShuffleExchangeExec(targetPartitioning, child) - } + children = children.zip(requiredChildDistributions).zipWithIndex.map { + case ((child, distribution), index) if childrenIndexes.contains(index) => + if (child.outputPartitioning.numPartitions == targetNumPartitions) { + child + } else { + val defaultPartitioning = distribution.createPartitioning(targetNumPartitions) + child match { + // If child is an exchange, we replace it with a new one having defaultPartitioning. + case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c) + case _ => ShuffleExchangeExec(defaultPartitioning, child) + } } - } + + case ((child, _), _) => child } } @@ -324,10 +292,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator @ ShuffleExchangeExec(partitioning, child, _) => - child.children match { - case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil => - if (childPartitioning.guarantees(partitioning)) child else operator + // TODO: remove this after we create a physical operator for `RepartitionByExpression`. 
+ case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) => + child.outputPartitioning match { + case lower: HashPartitioning if upper.semanticEquals(lower) => child case _ => operator } case operator: SparkPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 5a1e217082bc..b89203719541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.exchange import java.util.Random +import java.util.function.Supplier import org.apache.spark._ import org.apache.spark.rdd.RDD @@ -25,13 +26,15 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType import org.apache.spark.util.MutablePair +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} /** * Performs a shuffle that will result in the desired `newPartitioning`. @@ -150,12 +153,9 @@ object ShuffleExchangeExec { * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue. * * @param partitioner the partitioner for the shuffle - * @param serializer the serializer that will be used to write rows * @return true if rows should be copied before being shuffled, false otherwise */ - private def needToCopyObjectsBeforeShuffle( - partitioner: Partitioner, - serializer: Serializer): Boolean = { + private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): Boolean = { // Note: even though we only use the partitioner's `numPartitions` field, we require it to be // passed instead of directly passing the number of partitions in order to guard against // corner-cases where a partitioner constructed with `numPartitions` partitions may output @@ -164,22 +164,24 @@ object ShuffleExchangeExec { val shuffleManager = SparkEnv.get.shuffleManager val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + val numParts = partitioner.numPartitions if (sortBasedShuffleOn) { - val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { + if (numParts <= bypassMergeThreshold) { // If we're using the original SortShuffleManager and the number of output partitions is // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. 
false - } else if (serializer.supportsRelocationOfSerializedObjects) { + } else if (numParts <= SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records // prior to sorting them. This optimization is only applied in cases where shuffle // dependency does not specify an aggregator or ordering and the record serializer has - // certain properties. If this optimization is enabled, we can safely avoid the copy. + // certain properties and the number of partitions doesn't exceed the limitation. If this + // optimization is enabled, we can safely avoid the copy. // - // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only - // need to check whether the optimization is enabled and supported by our serializer. + // Exchange never configures its ShuffledRDDs with aggregators or key orderings, and the + // serializer in Spark SQL always satisfy the properties, so we only need to check whether + // the number of partitions exceeds the limitation. false } else { // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must @@ -247,14 +249,61 @@ object ShuffleExchangeExec { case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { - if (needToCopyObjectsBeforeShuffle(part, serializer)) { + // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic, + // otherwise a retry task may output different rows and thus lead to data loss. + // + // Currently we following the most straight-forward way that perform a local sort before + // partitioning. + // + // Note that we don't perform local sort if the new partitioning has only 1 partition, under + // that case all output rows go to the same partition. + val newRdd = if (SQLConf.get.sortBeforeRepartition && + newPartitioning.numPartitions > 1 && + newPartitioning.isInstanceOf[RoundRobinPartitioning]) { rdd.mapPartitionsInternal { iter => + val recordComparatorSupplier = new Supplier[RecordComparator] { + override def get: RecordComparator = new RecordBinaryComparator() + } + // The comparator for comparing row hashcode, which should always be Integer. + val prefixComparator = PrefixComparators.LONG + val canUseRadixSort = SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED) + // The prefix computer generates row hashcode as the prefix, so we may decrease the + // probability that the prefixes are equal when input rows choose column values from a + // limited range. + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix + override def computePrefix(row: InternalRow): + UnsafeExternalRowSorter.PrefixComputer.Prefix = { + // The hashcode generated from the binary form of a [[UnsafeRow]] should not be null. 
+ result.isNull = false + result.value = row.hashCode() + result + } + } + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + + val sorter = UnsafeExternalRowSorter.createWithRecordComparator( + StructType.fromAttributes(outputAttributes), + recordComparatorSupplier, + prefixComparator, + prefixComputer, + pageSize, + canUseRadixSort) + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + } + } else { + rdd + } + + if (needToCopyObjectsBeforeShuffle(part)) { + newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } } } else { - rdd.mapPartitionsInternal { iter => + newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() val mutablePair = new MutablePair[Int, InternalRow]() iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 1918fcc5482d..6fa716d9fade 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -22,12 +22,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.{BooleanType, LongType} import org.apache.spark.util.TaskCompletionListener /** @@ -182,16 +182,17 @@ case class BroadcastHashJoinExec( // the variables are needed even there is no matched rows val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") + val javaType = CodeGenerator.javaType(a.dataType) val code = s""" |boolean $isNull = true; - |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { | ${ev.code} | $isNull = ${ev.isNull}; | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)) } } } @@ -486,7 +487,8 @@ case class BroadcastHashJoinExec( s"$existsVar = true;" } - val resultVar = input ++ Seq(ExprCode("", "false", existsVar)) + val resultVar = input ++ Seq(ExprCode.forNonNullValue( + JavaCode.variable(existsVar, BooleanType))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 66e8031bb519..897a4dae39f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -46,7 +46,7 @@ case class ShuffledHashJoinExec( "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val buildDataSize = longMetric("buildDataSize") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 94405410cce9..d8261f0f33b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,11 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, -ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet @@ -78,7 +77,7 @@ case class SortMergeJoinExec( } override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil override def outputOrdering: Seq[SortOrder] = joinType match { // For inner join, orders of both sides keys should be kept. 
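Note on the distribution change above: both shuffled joins (and CoGroupExec later in this patch) now require HashClusteredDistribution instead of ClusteredDistribution. The sketch below is illustrative only; the class and method names are simplified stand-ins rather than the actual Catalyst implementations. It shows the co-partitioning hazard the stricter requirement guards against: two children can each satisfy a plain clustered requirement while hashing their keys differently, so equal join keys need not land in the same partition id on both sides.

// Minimal, self-contained sketch (NOT the real Spark classes) of why a hash-specific
// child distribution is needed for shuffled joins.
sealed trait Distribution
case class ClusteredDistribution(keys: Seq[String]) extends Distribution
case class HashClusteredDistribution(keys: Seq[String]) extends Distribution

sealed trait Partitioning {
  def satisfies(required: Distribution): Boolean
}

case class HashPartitioning(keys: Seq[String], numPartitions: Int) extends Partitioning {
  override def satisfies(required: Distribution): Boolean = required match {
    // Clustered: it is enough that the partitioning keys are among the required
    // clustering keys, so equal values stay in one partition on this side.
    case ClusteredDistribution(clustering) => keys.forall(clustering.contains)
    // Hash-clustered: the keys must match exactly and in order, which pins down
    // the partition id both children compute for a given join key.
    case HashClusteredDistribution(clustering) => keys == clustering
  }
}

object CoPartitioningCheck extends App {
  val left  = HashPartitioning(Seq("a", "b"), numPartitions = 200)
  val right = HashPartitioning(Seq("b", "a"), numPartitions = 200) // same keys, different order

  // Both sides satisfy the old, looser requirement...
  println(left.satisfies(ClusteredDistribution(Seq("a", "b"))))       // true
  println(right.satisfies(ClusteredDistribution(Seq("a", "b"))))      // true
  // ...but only one satisfies the hash-specific requirement, so a planner would
  // have to re-shuffle the right child rather than join mis-aligned partitions.
  println(right.satisfies(HashClusteredDistribution(Seq("a", "b"))))  // false
}

With the hash-specific requirement in place, the EnsureRequirements changes elsewhere in this patch add a ShuffleExchangeExec for any child whose output partitioning does not satisfy the required distribution, rather than accepting a merely compatible existing partitioning.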
@@ -516,9 +515,9 @@ case class SortMergeJoinExec( ctx.INPUT_ROW = leftRow left.output.zipWithIndex.map { case (a, i) => val value = ctx.freshName("value") - val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) - val javaType = ctx.javaType(a.dataType) - val defaultValue = ctx.defaultValue(a.dataType) + val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString) + val javaType = CodeGenerator.javaType(a.dataType) + val defaultValue = CodeGenerator.defaultValue(a.dataType) if (a.nullable) { val isNull = ctx.freshName("isNull") val code = @@ -531,11 +530,12 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, isNull, value), leftVarsDecl) + (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), + leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, "false", value), leftVarsDecl) + (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } }.unzip } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cccee63bc068..66bcda891373 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.util.Utils @@ -71,7 +71,8 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false + val stopEarly = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false ctx.addNewFunction("stopEarly", s""" @Override @@ -79,7 +80,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { return $stopEarly; } """, inlineToOuterClass = true) - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0 + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0 s""" | if ($countTerm < $limit) { | $countTerm += 1; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d1bd8a707686..03d1bbf2ab88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -456,7 +456,7 @@ case class CoGroupExec( right: SparkPlan) extends BinaryExecNode with ObjectProducerExec { override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + HashClusteredDistribution(leftGroup) :: HashClusteredDistribution(rightGroup) :: Nil override def requiredChildOrdering: Seq[Seq[SortOrder]] = 
leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala new file mode 100644 index 000000000000..8e01e8e56a5b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + +/** + * Physical node for aggregation with group aggregate Pandas UDF. + * + * This plan works by sending the necessary (projected) input grouped data as Arrow record batches + * to the python worker, the python worker invokes the UDF and sends the results to the executor, + * finally the executor evaluates any post-aggregation expressions and join the result with the + * grouped key. + */ +case class AggregateInPandasExec( + groupingExpressions: Seq[NamedExpression], + udfExpressions: Seq[PythonUDF], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryExecNode { + + override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. 
+ assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingExpressions.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val aggInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + inputRDD.mapPartitionsInternal { iter => + val prunedProj = UnsafeProjection.create(allInputs, child.output) + + val grouped = if (groupingExpressions.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, groupingExpressions, child.output) + }.map { case (key, rows) => + (key, rows.map(prunedProj)) + } + + val context = TaskContext.get() + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) + context.addTaskCompletionListener { _ => + queue.close() + } + + // Add rows to queue to join later with the result. 
+ val projectedRowIter = grouped.map { case (groupingKey, rows) => + queue.add(groupingKey.asInstanceOf[UnsafeRow]) + rows + } + + val columnarBatchIter = new ArrowPythonRunner( + pyFuncs, bufferSize, reuseWorker, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) + .compute(projectedRowIter, context.partitionId(), context) + + val joinedAttributes = + groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes) + + columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, aggOutputRow) + resultProj(joinedRow) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index c06bc7b66ff3..c4de214679ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -74,15 +74,14 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi schema: StructType, context: TaskContext): Iterator[InternalRow] = { - val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex - .map { case (attr, i) => attr.withName(s"_$i") }) + val outputTypes = output.drop(child.output.length).map(_.dataType) // DO NOT use iter.grouped(). See BatchIterator. val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(batchIter, context.partitionId(), context) @@ -90,8 +89,9 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private var currentIter = if (columnarBatchIter.hasNext) { val batch = columnarBatchIter.next() - assert(schemaOut.equals(batch.schema), - s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") + val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) + assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " + + s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}") batch.rowIterator.asScala } else { Iterator.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5cc8ed353565..5fcdcddca7d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -30,8 +30,8 @@ import org.apache.spark._ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils /** @@ -138,7 +138,7 
@@ class ArrowPythonRunner( if (reader != null && batchLoaded) { batchLoaded = reader.loadNextBatch() if (batchLoaded) { - val batch = new ColumnarBatch(schema, vectors, root.getRowCount) + val batch = new ColumnarBatch(vectors) batch.setNumRows(root.getRowCount) batch } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 26ee25f633ea..f4d83e8dc7c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -79,16 +79,19 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi } else { StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) } + + val fromJava = EvaluatePython.makeFromJava(resultType) + outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => if (udfs.length == 1) { // fast path for single UDF - mutableRow(0) = EvaluatePython.fromJava(result, resultType) + mutableRow(0) = fromJava(result) mutableRow } else { - EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + fromJava(result).asInstanceOf[InternalRow] } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 9bbfa6018ba7..520afad28764 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -83,82 +83,134 @@ object EvaluatePython { } /** - * Converts `obj` to the type specified by the data type, or returns null if the type of obj is - * unexpected. Because Python doesn't enforce the type. + * Make a converter that converts `obj` to the type specified by the data type, or returns + * null if the type of obj is unexpected. Because Python doesn't enforce the type. 
*/ - def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - case (null, _) => null - - case (c: Boolean, BooleanType) => c + def makeFromJava(dataType: DataType): Any => Any = dataType match { + case BooleanType => (obj: Any) => nullSafeConvert(obj) { + case b: Boolean => b + } - case (c: Byte, ByteType) => c - case (c: Short, ByteType) => c.toByte - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte + case ByteType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c + case c: Short => c.toByte + case c: Int => c.toByte + case c: Long => c.toByte + } - case (c: Byte, ShortType) => c.toShort - case (c: Short, ShortType) => c - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort + case ShortType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toShort + case c: Short => c + case c: Int => c.toShort + case c: Long => c.toShort + } - case (c: Byte, IntegerType) => c.toInt - case (c: Short, IntegerType) => c.toInt - case (c: Int, IntegerType) => c - case (c: Long, IntegerType) => c.toInt + case IntegerType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toInt + case c: Short => c.toInt + case c: Int => c + case c: Long => c.toInt + } - case (c: Byte, LongType) => c.toLong - case (c: Short, LongType) => c.toLong - case (c: Int, LongType) => c.toLong - case (c: Long, LongType) => c + case LongType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toLong + case c: Short => c.toLong + case c: Int => c.toLong + case c: Long => c + } - case (c: Float, FloatType) => c - case (c: Double, FloatType) => c.toFloat + case FloatType => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c + case c: Double => c.toFloat + } - case (c: Float, DoubleType) => c.toDouble - case (c: Double, DoubleType) => c + case DoubleType => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c.toDouble + case c: Double => c + } - case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale) + case dt: DecimalType => (obj: Any) => nullSafeConvert(obj) { + case c: java.math.BigDecimal => Decimal(c, dt.precision, dt.scale) + } - case (c: Int, DateType) => c + case DateType => (obj: Any) => nullSafeConvert(obj) { + case c: Int => c + } - case (c: Long, TimestampType) => c - // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs - case (c: Int, TimestampType) => c.toLong + case TimestampType => (obj: Any) => nullSafeConvert(obj) { + case c: Long => c + // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs + case c: Int => c.toLong + } - case (c, StringType) => UTF8String.fromString(c.toString) + case StringType => (obj: Any) => nullSafeConvert(obj) { + case _ => UTF8String.fromString(obj.toString) + } - case (c: String, BinaryType) => c.getBytes(StandardCharsets.UTF_8) - case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + case BinaryType => (obj: Any) => nullSafeConvert(obj) { + case c: String => c.getBytes(StandardCharsets.UTF_8) + case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + } - case (c: java.util.List[_], ArrayType(elementType, _)) => - new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) + case ArrayType(elementType, _) => + val elementFromJava = makeFromJava(elementType) - case (c, ArrayType(elementType, _)) if c.getClass.isArray => - new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) + (obj: Any) => 
nullSafeConvert(obj) { + case c: java.util.List[_] => + new GenericArrayData(c.asScala.map { e => elementFromJava(e) }.toArray) + case c if c.getClass.isArray => + new GenericArrayData(c.asInstanceOf[Array[_]].map(e => elementFromJava(e))) + } - case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) => - ArrayBasedMapData( - javaMap, - (key: Any) => fromJava(key, keyType), - (value: Any) => fromJava(value, valueType)) + case MapType(keyType, valueType, _) => + val keyFromJava = makeFromJava(keyType) + val valueFromJava = makeFromJava(valueType) + + (obj: Any) => nullSafeConvert(obj) { + case javaMap: java.util.Map[_, _] => + ArrayBasedMapData( + javaMap, + (key: Any) => keyFromJava(key), + (value: Any) => valueFromJava(value)) + } - case (c, StructType(fields)) if c.getClass.isArray => - val array = c.asInstanceOf[Array[_]] - if (array.length != fields.length) { - throw new IllegalStateException( - s"Input row doesn't have expected number of values required by the schema. " + - s"${fields.length} fields are required while ${array.length} values are provided." - ) + case StructType(fields) => + val fieldsFromJava = fields.map(f => makeFromJava(f.dataType)).toArray + + (obj: Any) => nullSafeConvert(obj) { + case c if c.getClass.isArray => + val array = c.asInstanceOf[Array[_]] + if (array.length != fields.length) { + throw new IllegalStateException( + s"Input row doesn't have expected number of values required by the schema. " + + s"${fields.length} fields are required while ${array.length} values are provided." + ) + } + + val row = new GenericInternalRow(fields.length) + var i = 0 + while (i < fields.length) { + row(i) = fieldsFromJava(i)(array(i)) + i += 1 + } + row } - new GenericInternalRow(array.zip(fields).map { - case (e, f) => fromJava(e, f.dataType) - }) - case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) + case udt: UserDefinedType[_] => makeFromJava(udt.sqlType) + + case other => (obj: Any) => nullSafeConvert(other)(PartialFunction.empty) + } - // all other unexpected type should be null, or we will have runtime exception - // TODO(davies): we could improve this by try to cast the object to expected type - case (c, _) => null + private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = { + if (input == null) { + null + } else { + f.applyOrElse(input, { + // all other unexpected type should be null, or we will have runtime exception + // TODO(davies): we could improve this by try to cast the object to expected type + _: Any => null + }) + } } private val module = "pyspark.sql.types" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 2f53fe788c7d..9d56f4824998 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or - * grouping key, evaluate them after aggregate. + * grouping key, or doesn't depend on any above expressions, evaluate them after aggregate. 
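The makeFromJava change above replaces the per-value fromJava(obj, dataType) dispatch with a converter that is resolved once per data type and then applied to every value. A standalone sketch of that pattern, using a made-up MiniType ADT rather than Spark's DataType hierarchy (illustrative only):

object ConverterSketch {
  // Toy stand-in for Spark's DataType hierarchy (illustrative only).
  sealed trait MiniType
  case object IntT extends MiniType
  case object LongT extends MiniType
  case class ArrayT(elem: MiniType) extends MiniType

  // Same shape as the nullSafeConvert helper above: nulls pass through,
  // values of an unexpected type become null instead of throwing.
  private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any =
    if (input == null) null else f.applyOrElse(input, (_: Any) => null)

  // The type is resolved once, up front; only the returned closure runs per value,
  // so the per-row work no longer re-dispatches on the data type.
  def makeConverter(t: MiniType): Any => Any = t match {
    case IntT => (obj: Any) => nullSafeConvert(obj) {
      case i: Int => i
      case l: Long => l.toInt
    }
    case LongT => (obj: Any) => nullSafeConvert(obj) {
      case i: Int => i.toLong
      case l: Long => l
    }
    case ArrayT(elem) =>
      val elemConverter = makeConverter(elem) // child converter is also built only once
      (obj: Any) => nullSafeConvert(obj) {
        case xs: Seq[_] => xs.map(elemConverter)
      }
  }

  def main(args: Array[String]): Unit = {
    val convert = makeConverter(ArrayT(LongT))
    println(convert(Seq(1, 2L, null))) // List(1, 2, null)
    println(convert("not a seq"))      // null, the unexpected-type fallback
  }
}

This is also why BatchEvalPythonExec above now hoists val fromJava = EvaluatePython.makeFromJava(resultType) out of the result-processing loop instead of calling the old two-argument fromJava for every row.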
*/ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { @@ -39,12 +39,14 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || + PythonUDF.isGroupAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { expr.find { - e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined + e => PythonUDF.isScalarPythonUDF(e) && + (e.references.isEmpty || e.find(belongAggregate(_, agg)).isDefined) }.isDefined } @@ -93,7 +95,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { private def hasPythonUDF(e: Expression): Boolean = { - e.find(_.isInstanceOf[PythonUDF]).isDefined + e.find(PythonUDF.isScalarPythonUDF).isDefined } private def canEvaluateInPython(e: PythonUDF): Boolean = { @@ -106,12 +108,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf) + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) case e => e.children.flatMap(collectEvaluatableUDF) } def apply(plan: SparkPlan): SparkPlan = plan transformUp { - // FlatMapGroupsInPandas can be evaluated directly in python worker + // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker // Therefore we don't need to extract the UDFs case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) @@ -149,17 +151,16 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { - require(validUdfs.forall(udf => - udf.evalType == PythonEvalType.SQL_BATCHED_UDF || - udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF - ), "Can only extract scalar vectorized udf or sql batch udf") + require( + validUdfs.forall(PythonUDF.isScalarPythonUDF), + "Can only extract scalar vectorized udf or sql batch udf") val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } val evaluation = validUdfs.partition( - _.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 59db66bd7adf..513e174c7733 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} @@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", 
defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) - val schema = StructType(child.schema.drop(groupingAttributes.length)) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + // Deduplicate the grouping attributes. + // If a grouping attribute also appears in data attributes, then we don't need to send the + // grouping attribute to Python worker. If a grouping attribute is not in data attributes, + // then we need to send this grouping attribute to python worker. + // + // We use argOffsets to distinguish grouping attributes and data attributes as following: + // + // argOffsets[0] is the length of grouping attributes + // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes + // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes + + val dataAttributes = child.output.drop(groupingAttributes.length) + val groupingIndicesInData = groupingAttributes.map { attribute => + dataAttributes.indexWhere(attribute.semanticEquals) + } + + val groupingArgOffsets = new ArrayBuffer[Int] + val nonDupGroupingAttributes = new ArrayBuffer[Attribute] + val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) + + // Non duplicate grouping attributes are added to nonDupGroupingAttributes and + // their offsets are 0, 1, 2 ... + // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and + // their offsets are n + index, where n is the total number of non duplicate grouping + // attributes and index is the index in the data attributes that the grouping attribute + // is a duplicate of. + + groupingAttributes.zip(groupingIndicesInData).foreach { + case (attribute, index) => + if (index == -1) { + groupingArgOffsets += nonDupGroupingAttributes.length + nonDupGroupingAttributes += attribute + } else { + groupingArgOffsets += index + nonDupGroupingSize + } + } + + val dataArgOffsets = nonDupGroupingAttributes.length until + (nonDupGroupingAttributes.length + dataAttributes.length) + + val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) + + // Attributes after deduplication + val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes + val dedupSchema = StructType.fromAttributes(dedupAttributes) + inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { Iterator(iter) } else { val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - val dropGrouping = - UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + case (_, groupedRowIter) => groupedRowIter.map(dedupProj) } } @@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(grouped, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 50dca32cb786..f4c2d02ee942 
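The argOffsets layout documented in the deduplication comment above is easiest to see with a small worked example. The sketch below recomputes the offsets for hypothetical grouping columns ("a", "b") and data columns ("b", "c"), using plain strings instead of Attributes and semanticEquals (illustrative only, not the executor code):

object ArgOffsetsSketch {
  // Layout, as described above:
  //   argOffsets(0)                = number of grouping columns
  //   argOffsets(1 .. nGrouping)   = offsets of the grouping columns in the deduplicated row
  //   argOffsets(nGrouping + 1 ..) = offsets of the data columns in the deduplicated row
  def argOffsets(grouping: Seq[String], data: Seq[String]): Array[Int] = {
    val indicesInData = grouping.map(g => data.indexOf(g)) // -1 when not also a data column
    val nonDupCount = indicesInData.count(_ == -1)

    var nextNonDup = 0
    val groupingOffsets = indicesInData.map {
      case -1 =>
        val offset = nextNonDup // fresh slot in front of the data columns: 0, 1, 2, ...
        nextNonDup += 1
        offset
      case i =>
        nonDupCount + i // duplicate: reuse the data column's slot, nothing is sent twice
    }
    val dataOffsets = nonDupCount until (nonDupCount + data.length)
    Array(grouping.length) ++ groupingOffsets ++ dataOffsets
  }

  def main(args: Array[String]): Unit = {
    // Grouping ("a", "b") with data ("b", "c"): the deduplicated row is (a, b, c),
    // so the result is [2, 0, 1, 1, 2].
    println(argOffsets(Seq("a", "b"), Seq("b", "c")).mkString("[", ", ", "]"))
  }
}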
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF} import org.apache.spark.sql.types.DataType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala new file mode 100644 index 000000000000..606ba250ad9d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io.{FileNotFoundException, IOException, OutputStream} +import java.util.{EnumSet, UUID} + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.fs.local.{LocalFs, RawLocalFs} +import org.apache.hadoop.fs.permission.FsPermission + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.RenameHelperMethods +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * An interface to abstract out all operation related to streaming checkpoints. Most importantly, + * the key operation this interface provides is `createAtomic(path, overwrite)` which returns a + * `CancellableFSDataOutputStream`. This method is used by [[HDFSMetadataLog]] and + * [[org.apache.spark.sql.execution.streaming.state.StateStore StateStore]] implementations + * to write a complete checkpoint file atomically (i.e. no partial file will be visible), with or + * without overwrite. + * + * This higher-level interface above the Hadoop FileSystem is necessary because + * different implementation of FileSystem/FileContext may have different combination of operations + * to provide the desired atomic guarantees (e.g. write-to-temp-file-and-rename, + * direct-write-and-cancel-on-failure) and this abstraction allow different implementations while + * keeping the usage simple (`createAtomic` -> `close` or `cancel`). + */ +trait CheckpointFileManager { + + import org.apache.spark.sql.execution.streaming.CheckpointFileManager._ + + /** + * Create a file and make its contents available atomically after the output stream is closed. 
+ * + * @param path Path to create + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file already exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def createAtomic(path: Path, overwriteIfPossible: Boolean): CancellableFSDataOutputStream + + /** Open a file for reading, or throw exception if it does not exist. */ + def open(path: Path): FSDataInputStream + + /** List the files in a path that match a filter. */ + def list(path: Path, filter: PathFilter): Array[FileStatus] + + /** List all the files in a path. */ + def list(path: Path): Array[FileStatus] = { + list(path, new PathFilter { override def accept(path: Path): Boolean = true }) + } + + /** Make directory at the given path and all its parent directories as needed. */ + def mkdirs(path: Path): Unit + + /** Whether path exists */ + def exists(path: Path): Boolean + + /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ + def delete(path: Path): Unit + + /** Whether the default file system this implementation operates on is the local file system. */ + def isLocal: Boolean +} + +object CheckpointFileManager extends Logging { + + /** + * Additional methods in CheckpointFileManager implementations that allow + * [[RenameBasedFSDataOutputStream]] to get atomicity by write-to-temp-file-and-rename + */ + sealed trait RenameHelperMethods { self: CheckpointFileManager => + /** Create a file with overwrite. */ + def createTempFile(path: Path): FSDataOutputStream + + /** + * Rename a file. + * + * @param srcPath Source path to rename + * @param dstPath Destination path to rename to + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file already exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit + } + + /** + * An interface to add the cancel() operation to [[FSDataOutputStream]]. This is used + * mainly by `CheckpointFileManager.createAtomic` to write a file atomically. + * + * @see [[CheckpointFileManager]]. + */ + abstract class CancellableFSDataOutputStream(protected val underlyingStream: OutputStream) + extends FSDataOutputStream(underlyingStream, null) { + /** Cancel the `underlyingStream` and ensure that the output file is not generated. */ + def cancel(): Unit + } + + /** + * An implementation of [[CancellableFSDataOutputStream]] that writes a file atomically by writing + * to a temporary file and then renames.
+ */ + sealed class RenameBasedFSDataOutputStream( + fm: CheckpointFileManager with RenameHelperMethods, + finalPath: Path, + tempPath: Path, + overwriteIfPossible: Boolean) + extends CancellableFSDataOutputStream(fm.createTempFile(tempPath)) { + + def this(fm: CheckpointFileManager with RenameHelperMethods, path: Path, overwrite: Boolean) = { + this(fm, path, generateTempPath(path), overwrite) + } + + logInfo(s"Writing atomically to $finalPath using temp file $tempPath") + @volatile private var terminated = false + + override def close(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + try { + fm.renameTempFile(tempPath, finalPath, overwriteIfPossible) + } catch { + case fe: FileAlreadyExistsException => + logWarning( + s"Failed to rename temp file $tempPath to $finalPath because file exists", fe) + if (!overwriteIfPossible) throw fe + } + logInfo(s"Renamed temp file $tempPath to $finalPath") + } finally { + terminated = true + } + } + + override def cancel(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + fm.delete(tempPath) + } catch { + case NonFatal(e) => + logWarning(s"Error cancelling write to $finalPath", e) + } finally { + terminated = true + } + } + } + + + /** Create an instance of [[CheckpointFileManager]] based on the path and configuration. */ + def create(path: Path, hadoopConf: Configuration): CheckpointFileManager = { + val fileManagerClass = hadoopConf.get( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key) + if (fileManagerClass != null) { + return Utils.classForName(fileManagerClass) + .getConstructor(classOf[Path], classOf[Configuration]) + .newInstance(path, hadoopConf) + .asInstanceOf[CheckpointFileManager] + } + try { + // Try to create a manager based on `FileContext` because HDFS's `FileContext.rename() + // gives atomic renames, which is what we rely on for the default implementation + // `CheckpointFileManager.createAtomic`. + new FileContextBasedCheckpointFileManager(path, hadoopConf) + } catch { + case e: UnsupportedFileSystemException => + logWarning( + "Could not use FileContext API for managing Structured Streaming checkpoint files at " + + s"$path. Using FileSystem API instead for managing log files. If the implementation " + + s"of FileSystem.rename() is not atomic, then the correctness and fault-tolerance of" + + s"your Structured Streaming is not guaranteed.") + new FileSystemBasedCheckpointFileManager(path, hadoopConf) + } + } + + private def generateTempPath(path: Path): Path = { + val tc = org.apache.spark.TaskContext.get + val tid = if (tc != null) ".TID" + tc.taskAttemptId else "" + new Path(path.getParent, s".${path.getName}.${UUID.randomUUID}${tid}.tmp") + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileSystem]] API. 
*/ +class FileSystemBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + protected val fs = path.getFileSystem(hadoopConf) + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fs.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fs.mkdirs(path, FsPermission.getDirDefault) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + fs.create(path, true) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fs.open(path) + } + + override def exists(path: Path): Boolean = { + try + return fs.getFileStatus(path) != null + catch { + case e: FileNotFoundException => + return false + } + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + if (!overwriteIfPossible && fs.exists(dstPath)) { + throw new FileAlreadyExistsException( + s"Failed to rename $srcPath to $dstPath as destination already exists") + } + + if (!fs.rename(srcPath, dstPath)) { + // FileSystem.rename() returning false is very ambiguous as it can be for many reasons. + // This tries to make a best effort attempt to return the most appropriate exception. + if (fs.exists(dstPath)) { + if (!overwriteIfPossible) { + throw new FileAlreadyExistsException(s"Failed to rename as $dstPath already exists") + } + } else if (!fs.exists(srcPath)) { + throw new FileNotFoundException(s"Failed to rename as $srcPath was not found") + } else { + val msg = s"Failed to rename temp file $srcPath to $dstPath as rename returned false" + logWarning(msg) + throw new IOException(msg) + } + } + } + + override def delete(path: Path): Unit = { + try { + fs.delete(path, true) + } catch { + case e: FileNotFoundException => + logInfo(s"Failed to delete $path as it does not exist") + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fs match { + case _: LocalFileSystem | _: RawLocalFileSystem => true + case _ => false + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileContext]] API. 
*/ +class FileContextBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + private val fc = if (path.toUri.getScheme == null) { + FileContext.getFileContext(hadoopConf) + } else { + FileContext.getFileContext(path.toUri, hadoopConf) + } + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fc.util.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fc.mkdir(path, FsPermission.getDirDefault, true) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + import CreateFlag._ + import Options._ + fc.create( + path, EnumSet.of(CREATE, OVERWRITE), CreateOpts.checksumParam(ChecksumOpt.createDisabled())) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fc.open(path) + } + + override def exists(path: Path): Boolean = { + fc.util.exists(path) + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + import Options.Rename._ + fc.rename(srcPath, dstPath, if (overwriteIfPossible) OVERWRITE else NONE) + } + + + override def delete(path: Path): Unit = { + try { + fc.delete(path, true) + } catch { + case e: FileNotFoundException => + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fc.getDefaultFileSystem match { + case _: LocalFs | _: RawLocalFs => true // LocalFs = RawLocalFs + ChecksumFs + case _ => false + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 2715fa93d0e9..b3d12f67b5d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -26,7 +26,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter} +import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter} +import org.apache.spark.util.SerializableConfiguration object FileStreamSink extends Logging { // The name of the subdirectory that is used to store metadata about which files are valid. 
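The new CheckpointFileManager above is meant to be driven with a close-or-cancel discipline. A minimal caller sketch, assuming it sits in the same package as CheckpointFileManager; writeCheckpointFile and CheckpointWriteSketch are illustrative names, not part of the patch:

import org.apache.hadoop.fs.Path

object CheckpointWriteSketch {
  // A successful write must end in close(); a failed or abandoned one in cancel(),
  // so that no partially written checkpoint file ever becomes visible to readers.
  def writeCheckpointFile(fm: CheckpointFileManager, path: Path, bytes: Array[Byte]): Unit = {
    val out = fm.createAtomic(path, overwriteIfPossible = false)
    try {
      out.write(bytes)
      out.close() // publishes the file, e.g. by renaming the temp file into place
    } catch {
      case e: Throwable =>
        out.cancel() // discards the temp file; nothing partial is left behind
        throw e
    }
  }
}

With overwriteIfPossible = false, a racing writer surfaces as a FileAlreadyExistsException out of close(); the reworked HDFSMetadataLog further down relies on exactly that to detect multiple streaming queries sharing one checkpoint location.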
@@ -42,9 +43,11 @@ object FileStreamSink extends Logging { try { val hdfsPath = new Path(singlePath) val fs = hdfsPath.getFileSystem(hadoopConf) - val metadataPath = new Path(hdfsPath, metadataDir) - val res = fs.exists(metadataPath) - res + if (fs.isDirectory(hdfsPath)) { + fs.exists(new Path(hdfsPath, metadataDir)) + } else { + false + } } catch { case NonFatal(e) => logWarning(s"Error while looking for metadata directory.") @@ -95,6 +98,11 @@ class FileStreamSink( new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString) private val hadoopConf = sparkSession.sessionState.newHadoopConf() + private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = { + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics) + } + override def addBatch(batchId: Long, data: DataFrame): Unit = { if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) { logInfo(s"Skipping already committed batch $batchId") @@ -129,7 +137,7 @@ class FileStreamSink( hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = None, - statsTrackers = Nil, + statsTrackers = Seq(basicWriteJobStatsTracker), options = options) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 0debd7db8475..8c016abc5b64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -166,7 +166,7 @@ class FileStreamSource( val newDataSource = DataSource( sparkSession, - paths = files.map(_.path), + paths = files.map(f => new Path(new URI(f.path)).toString), userSpecifiedSchema = Some(schema), partitionColumns = partitionColumns, className = fileFormatClassName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala deleted file mode 100644 index 2cc54107f8b8..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.TaskContext -import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter} -import org.apache.spark.sql.catalyst.encoders.encoderFor - -/** - * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by - * [[ForeachWriter]]. - * - * @param writer The [[ForeachWriter]] to process all data. 
- * @tparam T The expected type of the sink. - */ -class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable { - - override def addBatch(batchId: Long, data: DataFrame): Unit = { - // This logic should've been as simple as: - // ``` - // data.as[T].foreachPartition { iter => ... } - // ``` - // - // Unfortunately, doing that would just break the incremental planing. The reason is, - // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will - // create a new plan. Because StreamExecution uses the existing plan to collect metrics and - // update watermark, we should never create a new plan. Otherwise, metrics and watermark are - // updated in the new plan, and StreamExecution cannot retrieval them. - // - // Hence, we need to manually convert internal rows to objects using encoder. - val encoder = encoderFor[T].resolveAndBind( - data.logicalPlan.output, - data.sparkSession.sessionState.analyzer) - data.queryExecution.toRdd.foreachPartition { iter => - if (writer.open(TaskContext.getPartitionId(), batchId)) { - try { - while (iter.hasNext) { - writer.process(encoder.fromRow(iter.next())) - } - } catch { - case e: Throwable => - writer.close(e) - throw e - } - writer.close(null) - } else { - writer.close(null) - } - } - } - - override def toString(): String = "ForeachSink" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 6e8154d58d4c..bd0a46115ceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -57,10 +57,10 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(implicitly[ClassTag[T]].runtimeClass != classOf[Seq[_]], "Should not create a log with type Seq, use Arrays instead - see SPARK-17372") - import HDFSMetadataLog._ - val metadataPath = new Path(path) - protected val fileManager = createFileManager() + + protected val fileManager = + CheckpointFileManager.create(metadataPath, sparkSession.sessionState.newHadoopConf) if (!fileManager.exists(metadataPath)) { fileManager.mkdirs(metadataPath) @@ -109,84 +109,31 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(metadata != null, "'null' metadata cannot written to a metadata log") get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written - writeBatch(batchId, metadata) + writeBatchToFile(metadata, batchIdToPath(batchId)) true } } - private def writeTempBatch(metadata: T): Option[Path] = { - while (true) { - val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp") - try { - val output = fileManager.create(tempPath) - try { - serialize(metadata, output) - return Some(tempPath) - } finally { - output.close() - } - } catch { - case e: FileAlreadyExistsException => - // Failed to create "tempPath". There are two cases: - // 1. Someone is creating "tempPath" too. - // 2. This is a restart. "tempPath" has already been created but not moved to the final - // batch file (not committed). - // - // For both cases, the batch has not yet been committed. So we can retry it. - // - // Note: there is a potential risk here: if HDFSMetadataLog A is running, people can use - // the same metadata path to create "HDFSMetadataLog" and fail A. 
However, this is not a - // big problem because it requires the attacker must have the permission to write the - // metadata path. In addition, the old Streaming also have this issue, people can create - // malicious checkpoint files to crash a Streaming application too. - } - } - None - } - - /** - * Write a batch to a temp file then rename it to the batch file. + /** Write a batch to a temp file then rename it to the batch file. * * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a * valid behavior, we still need to prevent it from destroying the files. */ - private def writeBatch(batchId: Long, metadata: T): Unit = { - val tempPath = writeTempBatch(metadata).getOrElse( - throw new IllegalStateException(s"Unable to create temp batch file $batchId")) + private def writeBatchToFile(metadata: T, path: Path): Unit = { + val output = fileManager.createAtomic(path, overwriteIfPossible = false) try { - // Try to commit the batch - // It will fail if there is an existing file (someone has committed the batch) - logDebug(s"Attempting to write log #${batchIdToPath(batchId)}") - fileManager.rename(tempPath, batchIdToPath(batchId)) - - // SPARK-17475: HDFSMetadataLog should not leak CRC files - // If the underlying filesystem didn't rename the CRC file, delete it. - val crcPath = new Path(tempPath.getParent(), s".${tempPath.getName()}.crc") - if (fileManager.exists(crcPath)) fileManager.delete(crcPath) + serialize(metadata, output) + output.close() } catch { case e: FileAlreadyExistsException => - // If "rename" fails, it means some other "HDFSMetadataLog" has committed the batch. - // So throw an exception to tell the user this is not a valid behavior. + output.cancel() + // If next batch file already exists, then another concurrently running query has + // written it. throw new ConcurrentModificationException( - s"Multiple HDFSMetadataLog are using $path", e) - } finally { - fileManager.delete(tempPath) - } - } - - /** - * @return the deserialized metadata in a batch file, or None if file not exist. - * @throws IllegalArgumentException when path does not point to a batch file. - */ - def get(batchFile: Path): Option[T] = { - if (fileManager.exists(batchFile)) { - if (isBatchFile(batchFile)) { - get(pathToBatchId(batchFile)) - } else { - throw new IllegalArgumentException(s"File ${batchFile} is not a batch file!") - } - } else { - None + s"Multiple streaming queries are concurrently using $path", e) + case e: Throwable => + output.cancel() + throw e } } @@ -219,7 +166,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) }.sorted - verifyBatchIds(batchIds, startId, endId) + HDFSMetadataLog.verifyBatchIds(batchIds, startId, endId) batchIds.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { case (batchId, metadataOption) => @@ -280,19 +227,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } } - private def createFileManager(): FileManager = { - val hadoopConf = sparkSession.sessionState.newHadoopConf() - try { - new FileContextManager(metadataPath, hadoopConf) - } catch { - case e: UnsupportedFileSystemException => - logWarning("Could not use FileContext API for managing metadata log files at path " + - s"$metadataPath. Using FileSystem API instead for managing log files. 
The log may be " + - s"inconsistent under failures.") - new FileSystemManager(metadataPath, hadoopConf) - } - } - /** * Parse the log version from the given `text` -- will throw exception when the parsed version * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", @@ -327,135 +261,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: object HDFSMetadataLog { - /** A simple trait to abstract out the file management operations needed by HDFSMetadataLog. */ - trait FileManager { - - /** List the files in a path that matches a filter. */ - def list(path: Path, filter: PathFilter): Array[FileStatus] - - /** Make directory at the give path and all its parent directories as needed. */ - def mkdirs(path: Path): Unit - - /** Whether path exists */ - def exists(path: Path): Boolean - - /** Open a file for reading, or throw exception if it does not exist. */ - def open(path: Path): FSDataInputStream - - /** Create path, or throw exception if it already exists */ - def create(path: Path): FSDataOutputStream - - /** - * Atomically rename path, or throw exception if it cannot be done. - * Should throw FileNotFoundException if srcPath does not exist. - * Should throw FileAlreadyExistsException if destPath already exists. - */ - def rename(srcPath: Path, destPath: Path): Unit - - /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ - def delete(path: Path): Unit - } - - /** - * Default implementation of FileManager using newer FileContext API. - */ - class FileContextManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fc = if (path.toUri.getScheme == null) { - FileContext.getFileContext(hadoopConf) - } else { - FileContext.getFileContext(path.toUri, hadoopConf) - } - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fc.util.listStatus(path, filter) - } - - override def rename(srcPath: Path, destPath: Path): Unit = { - fc.rename(srcPath, destPath) - } - - override def mkdirs(path: Path): Unit = { - fc.mkdir(path, FsPermission.getDirDefault, true) - } - - override def open(path: Path): FSDataInputStream = { - fc.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fc.create(path, EnumSet.of(CreateFlag.CREATE)) - } - - override def exists(path: Path): Boolean = { - fc.util().exists(path) - } - - override def delete(path: Path): Unit = { - try { - fc.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - - /** - * Implementation of FileManager using older FileSystem API. Note that this implementation - * cannot provide atomic renaming of paths, hence can lead to consistency issues. This - * should be used only as a backup option, when FileContextManager cannot be used. - */ - class FileSystemManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fs = path.getFileSystem(hadoopConf) - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fs.listStatus(path, filter) - } - - /** - * Rename a path. Note that this implementation is not atomic. - * @throws FileNotFoundException if source path does not exist. - * @throws FileAlreadyExistsException if destination path already exists. - * @throws IOException if renaming fails for some unknown reason. 
- */ - override def rename(srcPath: Path, destPath: Path): Unit = { - if (!fs.exists(srcPath)) { - throw new FileNotFoundException(s"Source path does not exist: $srcPath") - } - if (fs.exists(destPath)) { - throw new FileAlreadyExistsException(s"Destination path already exists: $destPath") - } - if (!fs.rename(srcPath, destPath)) { - throw new IOException(s"Failed to rename $srcPath to $destPath") - } - } - - override def mkdirs(path: Path): Unit = { - fs.mkdirs(path, FsPermission.getDirDefault) - } - - override def open(path: Path): FSDataInputStream = { - fs.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fs.create(path, false) - } - - override def exists(path: Path): Boolean = { - fs.exists(path) - } - - override def delete(path: Path): Unit = { - try { - fs.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - /** * Verify if batchIds are continuous and between `startId` and `endId`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a10ed5f2df1b..c480b96626f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -62,7 +62,7 @@ class IncrementalExecution( StreamingDeduplicationStrategy :: Nil } - private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) + private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) @@ -143,4 +143,14 @@ class IncrementalExecution( /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } + + /** + * Should the MicroBatchExecution run another batch based on this execution and the current + * updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + executedPlan.collect { + case p: StateStoreWriter => p.shouldRunAnotherBatch(newMetadata) + }.exists(_ == true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index 5f0b195fcfcb..3ff5b86ac45d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution.streaming +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} + /** * A simple offset for sources that produce a single linear stream of data. 
*/ -case class LongOffset(offset: Long) extends Offset { +case class LongOffset(offset: Long) extends OffsetV2 { override val json = offset.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala index 1da703cefd8e..5cacdd070b73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.StructType * A [[FileIndex]] that generates the list of files to processing by reading them from the * metadata log files generated by the [[FileStreamSink]]. * - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ class MetadataLogFileIndex( sparkSession: SparkSession, path: Path, - userPartitionSchema: Option[StructType]) - extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) { + userSpecifiedSchema: Option[StructType]) + extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) { private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") @@ -51,7 +51,7 @@ class MetadataLogFileIndex( } override protected val leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = { - allFilesFromLog.toArray.groupBy(_.getPath.getParent) + allFilesFromLog.groupBy(_.getPath.getParent) } override def rootPaths: Seq[Path] = path :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala index b84e6ce64c61..66b11ecddf23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala @@ -17,15 +17,11 @@ package org.apache.spark.sql.execution.streaming -import java.{util => ju} - -import scala.collection.mutable - import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.{Source => CodahaleSource} -import org.apache.spark.util.Clock +import org.apache.spark.sql.streaming.StreamingQueryProgress /** * Serves metrics from a [[org.apache.spark.sql.streaming.StreamingQuery]] to @@ -39,14 +35,17 @@ class MetricsReporter( // Metric names should not have . 
in them, so that all the metrics of a query are identified // together in Ganglia as a single metric group - registerGauge("inputRate-total", () => stream.lastProgress.inputRowsPerSecond) - registerGauge("processingRate-total", () => stream.lastProgress.processedRowsPerSecond) - registerGauge("latency", () => stream.lastProgress.durationMs.get("triggerExecution").longValue()) - - private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = { + registerGauge("inputRate-total", _.inputRowsPerSecond, 0.0) + registerGauge("processingRate-total", _.processedRowsPerSecond, 0.0) + registerGauge("latency", _.durationMs.get("triggerExecution").longValue(), 0L) + + private def registerGauge[T]( + name: String, + f: StreamingQueryProgress => T, + default: T): Unit = { synchronized { metricRegistry.register(name, new Gauge[T] { - override def getValue: T = f() + override def getValue: T = Option(stream.lastProgress).map(f).getOrElse(default) }) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 9a7a13fcc580..6709e7052f00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,14 +17,21 @@ package org.apache.spark.sql.execution.streaming -import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} +import java.util.Optional + +import scala.collection.JavaConverters._ +import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -33,10 +40,11 @@ class MicroBatchExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: Sink, + sink: BaseStreamingSink, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, + extraOptions: Map[String, String], deleteCheckpointOnStop: Boolean) extends StreamExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, @@ -44,12 +52,17 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + private val readerToDataSourceMap = + MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] + private val 
triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } + private val watermarkTracker = new WatermarkTracker() + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in QueryExecutionThread " + @@ -57,26 +70,54 @@ class MicroBatchExecution( var nextSourceId = 0L val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]() + // We transform each distinct streaming relation into a StreamingExecutionRelation, keeping a + // map as we go to ensure each identical relation gets the same StreamingExecutionRelation + // object. For each microbatch, the StreamingExecutionRelation will be replaced with a logical + // plan for the data within that batch. + // Note that we have to use the previous `output` as attributes in StreamingExecutionRelation, + // since the existing logical plan has already used those attributes. The per-microbatch + // transformation is responsible for replacing attributes with their final values. + + val disabledSources = + sparkSession.sqlContext.conf.disabledV2StreamingMicroBatchReaders.split(",") + val _logicalPlan = analyzedPlan.transform { - case streamingRelation@StreamingRelation(dataSource, _, output) => + case streamingRelation@StreamingRelation(dataSourceV1, sourceName, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) + val source = dataSourceV1.createSource(metadataPath) nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. + logInfo(s"Using Source [$source] from DataSourceV1 named '$sourceName' [$dataSourceV1]") StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2(v2DataSource, _, _, output, v1DataSource) - if !v2DataSource.isInstanceOf[MicroBatchReadSupport] => + case s @ StreamingRelationV2( + dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if + !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val source = v1DataSource.createSource(metadataPath) + val reader = dataSourceV2.createMicroBatchReader( + Optional.empty(), // user specified schema + metadataPath, + new DataSourceOptions(options.asJava)) nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. 
+ readerToDataSourceMap(reader) = dataSourceV2 -> options + logInfo(s"Using MicroBatchReader [$reader] from " + + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") + StreamingExecutionRelation(reader, output)(sparkSession) + }) + case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => + v2ToExecutionRelationMap.getOrElseUpdate(s, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + if (v1Relation.isEmpty) { + throw new UnsupportedOperationException( + s"Data source $sourceName does not support microbatch processing.") + } + val source = v1Relation.get.dataSource.createSource(metadataPath) + nextSourceId += 1 + logInfo(s"Using Source [$source] from DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(source, output)(sparkSession) }) } @@ -89,40 +130,55 @@ class MicroBatchExecution( * Repeatedly attempts to run batches as data arrives. */ protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - triggerExecutor.execute(() => { - startTrigger() + val noDataBatchesEnabled = + sparkSessionForStream.sessionState.conf.streamingNoDataMicroBatchesEnabled + + triggerExecutor.execute(() => { if (isActive) { + var currentBatchIsRunnable = false // Whether the current batch is runnable / has been run + var currentBatchHasNewData = false // Whether the current batch had new data + + startTrigger() + reportTimeTaken("triggerExecution") { + // We'll do this initialization only once every start / restart if (currentBatchId < 0) { - // We'll do this initialization only once populateStartOffsets(sparkSessionForStream) - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() + logInfo(s"Stream started from $committedOffsets") } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") + + // Set this before calling constructNextBatch() so any Spark jobs executed by sources + // while getting new data have the correct description + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) + + // Try to construct the next batch. This will return true only if the next batch is + // ready and runnable. Note that the current batch may be runnable even without + // new data to process as `constructNextBatch` may decide to run a batch for + // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data + // is available or not. + currentBatchIsRunnable = constructNextBatch(noDataBatchesEnabled) + + // Remember whether the current batch has data or not. This will be required later + // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed + // to false as the batch would have already processed the available data. + currentBatchHasNewData = isNewDataAvailable + + currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable) + if (currentBatchIsRunnable) { + if (currentBatchHasNewData) updateStatusMessage("Processing new data") + else updateStatusMessage("No new data but cleaning up state") runBatch(sparkSessionForStream) + } else { + updateStatusMessage("Waiting for data to arrive") } } - // Report trigger as finished and construct progress object. - finishTrigger(dataAvailable) - if (dataAvailable) { - // Update committed offsets. 
- commitLog.add(currentBatchId) - committedOffsets ++= availableOffsets - logDebug(s"batch ${currentBatchId} committed") - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) - } + + finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded + + // If the current batch has been executed, then increment the batch id, else there was + // no data to execute the batch + if (currentBatchIsRunnable) currentBatchId += 1 else Thread.sleep(pollingDelayMs) } updateStatusMessage("Waiting for next trigger") isActive @@ -172,6 +228,7 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + watermarkTracker.setWatermark(metadata.batchWatermarkMs) } /* identify the current batch id: if commit log indicates we successfully processed the @@ -187,17 +244,15 @@ class MicroBatchExecution( * batch will be executed before getOffset is called again. */ availableOffsets.foreach { case (source: Source, end: Offset) => - if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { - val start = committedOffsets.get(source) - source.getBatch(start, end) - } + val start = committedOffsets.get(source) + source.getBatch(start, end) case nonV1Tuple => - throw new IllegalStateException(s"Unexpected V2 source in $nonV1Tuple") + // The V2 API does not have the same edge case requiring getBatch to be called + // here, so we do nothing here. } currentBatchId = latestCommittedBatchId + 1 committedOffsets ++= availableOffsets // Construct a new batch be recomputing availableOffsets - constructNextBatch() } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -205,19 +260,18 @@ class MicroBatchExecution( } case None => logInfo("no commit log present") } - logDebug(s"Resuming at batch $currentBatchId with committed offsets " + + logInfo(s"Resuming at batch $currentBatchId with committed offsets " + s"$committedOffsets and available offsets $availableOffsets") case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 - constructNextBatch() } } /** * Returns true if there is any new data available to be processed. */ - private def dataAvailable: Boolean = { + private def isNewDataAvailable: Boolean = { availableOffsets.exists { case (source, available) => committedOffsets @@ -228,79 +282,63 @@ class MicroBatchExecution( } /** - * Queries all of the sources to see if any new data is available. When there is new data the - * batchId counter is incremented and a new log entry is written with the newest offsets. + * Attempts to construct a batch according to: + * - Availability of new data + * - Need for timeouts and state cleanups in stateful operators + * + * Returns true only if the next batch should be executed. + * + * Here is the high-level logic on how this constructs the next batch. 
+ * - Check each source whether new data is available + * - Updated the query's metadata and check using the last execution whether there is any need + * to run another batch (for state clean up, etc.) + * - If either of the above is true, then construct the next batch by committing to the offset + * log that range of offsets that the next batch will process. */ - private def constructNextBatch(): Unit = { - // Check to see what new data is available. - val hasNewData = { - awaitProgressLock.lock() - try { - val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { - case s: Source => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - (s, s.getOffset) - } - }.toMap - availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) + private def constructNextBatch(noDataBatchesEnables: Boolean): Boolean = withProgressLocked { + // If new data is already available that means this method has already been called before + // and it must have already committed the offset range of next batch to the offset log. + // Hence do nothing, just return true. + if (isNewDataAvailable) return true - if (dataAvailable) { - true - } else { - noNewData = true - false + // Generate a map from each unique source to the next available offset. + val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { + case s: Source => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + (s, s.getOffset) } - } finally { - awaitProgressLock.unlock() - } - } - if (hasNewData) { - var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs - // Update the eventTime watermarks if we find any in the plan. - if (lastExecution != null) { - lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec => e - }.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") - val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = watermarkMsMap.get(index) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - watermarkMsMap.put(index, newWatermarkMs) - } - - // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!watermarkMsMap.isDefinedAt(index)) { - watermarkMsMap.put(index, 0) - } + case s: MicroBatchReader => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("setOffsetRange") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) } - // Update the global watermark to the minimum of all watermark nodes. - // This is the safest option, because only the global watermark is fault-tolerant. Making - // it the minimum of all individual watermarks guarantees it will never advance past where - // any individual watermark operator would be if it were in a plan by itself. 
- if(!watermarkMsMap.isEmpty) { - val newWatermarkMs = watermarkMsMap.minBy(_._2)._2 - if (newWatermarkMs > batchWatermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - batchWatermarkMs = newWatermarkMs - } else { - logDebug( - s"Event time didn't move: $newWatermarkMs < " + - s"$batchWatermarkMs") - } - } - } - offsetSeqMetadata = offsetSeqMetadata.copy( - batchWatermarkMs = batchWatermarkMs, - batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds + val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } + (s, Option(currentOffset)) + }.toMap + availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) + + // Update the query metadata + offsetSeqMetadata = offsetSeqMetadata.copy( + batchWatermarkMs = watermarkTracker.currentWatermark, + batchTimestampMs = triggerClock.getTimeMillis()) + + // Check whether next batch should be constructed + val lastExecutionRequiresAnotherBatch = noDataBatchesEnables && + Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata)) + val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch + if (shouldConstructNextBatch) { + // Commit the next batch offset range to the offset log updateStatusMessage("Writing offsets to log") reportTimeTaken("walCommit") { - assert(offsetLog.add( - currentBatchId, + assert(offsetLog.add(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)), s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") logInfo(s"Committed offsets for batch $currentBatchId. " + @@ -317,9 +355,11 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) + case (reader: MicroBatchReader, off) => + reader.commit(reader.deserializeOffset(off.json)) } } else { - throw new IllegalStateException(s"batch $currentBatchId doesn't exist") + throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist") } } @@ -330,15 +370,12 @@ class MicroBatchExecution( commitLog.purge(currentBatchId - minLogEntriesToMaintain) } } + noNewData = false } else { - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() - } + noNewData = true + awaitProgressLockCondition.signalAll() } + shouldConstructNextBatch } /** @@ -346,6 +383,8 @@ class MicroBatchExecution( * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. */ private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { + logDebug(s"Running batch $currentBatchId") + // Request unprocessed data from all sources. 
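Taken on its own, the batch-construction decision introduced above reduces to a single predicate. The sketch below is a simplified illustration with plain boolean stand-ins for isNewDataAvailable, the streamingNoDataMicroBatchesEnabled conf, and lastExecution.shouldRunAnotherBatch; it is not the actual method, which also fetches offsets from every source and commits the chosen offset range to the offset log.

// Illustrative sketch only, with hypothetical boolean parameters standing in for the
// real state queried by constructNextBatch().
def shouldConstructNextBatch(
    isNewDataAvailable: Boolean,
    noDataBatchesEnabled: Boolean,
    lastExecutionNeedsAnotherBatch: Boolean): Boolean = {
  // Run a batch if there is new data, or if the previous execution asked for another
  // (possibly empty) batch, e.g. to fire event-time timers or clean up expired state.
  isNewDataAvailable || (noDataBatchesEnabled && lastExecutionNeedsAnotherBatch)
}

When the predicate is false, the trigger loop reports "Waiting for data to arrive" and sleeps for the polling delay before the next attempt.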
newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { @@ -357,33 +396,54 @@ class MicroBatchExecution( s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" + s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") - Some(source -> batch) + Some(source -> batch.logicalPlan) + case (reader: MicroBatchReader, available) + if committedOffsets.get(reader).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) + val availableV2: OffsetV2 = available match { + case v1: SerializedOffset => reader.deserializeOffset(v1.json) + case v2: OffsetV2 => v2 + } + reader.setOffsetRange( + toJava(current), + Optional.of(availableV2)) + logDebug(s"Retrieving data from $reader: $current -> $availableV2") + + val (source, options) = reader match { + // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` + // implementation. We provide a fake one here for explain. + case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] + // Provide a fake value here just in case something went wrong, e.g. the reader gives + // a wrong `equals` implementation. + case _ => readerToDataSourceMap.getOrElse(reader, { + FakeDataSourceV2 -> Map.empty[String, String] + }) + } + Some(reader -> StreamingDataSourceV2Relation( + reader.readSchema().toAttributes, source, options, reader)) case _ => None } } - // A list of attributes that will need to be updated. - val replacements = new ArrayBuffer[(Attribute, Attribute)] // Replace sources in the logical plan with data that has arrived since the last batch. - val withNewSources = logicalPlan transform { + val newBatchesPlan = logicalPlan transform { case StreamingExecutionRelation(source, output) => - newData.get(source).map { data => - val newPlan = data.logicalPlan - assert(output.size == newPlan.output.size, + newData.get(source).map { dataPlan => + assert(output.size == dataPlan.output.size, s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(newPlan.output, ",")}") - replacements ++= output.zip(newPlan.output) - newPlan + s"${Utils.truncatedString(dataPlan.output, ",")}") + + val aliases = output.zip(dataPlan.output).map { case (to, from) => + Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata)) + } + Project(aliases, dataPlan) }.getOrElse { LocalRelation(output, isStreaming = true) } } // Rewire the plan to use the new attributes that were returned by the source. 
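The Project of aliases built above is the heart of the change from the old attribute-rewrite approach: each source's per-batch plan is re-exposed under the original relation's expression IDs, so references elsewhere in the query plan keep resolving across batches. A minimal standalone sketch of that step follows, using the same Catalyst classes; bindBatchPlan is a hypothetical helper name introduced here for illustration, not something defined in this file.

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

// Project a per-batch plan onto the attributes of the original streaming relation,
// preserving the original expression IDs and carrying over the batch plan's metadata.
def bindBatchPlan(output: Seq[Attribute], dataPlan: LogicalPlan): LogicalPlan = {
  val aliases = output.zip(dataPlan.output).map { case (to, from) =>
    Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata))
  }
  Project(aliases, dataPlan)
}

Aliasing at the leaves avoids the old approach of walking the entire plan with a replacement map on every trigger.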
- val replacementMap = AttributeMap(replacements) - val triggerLogicalPlan = withNewSources transformAllExpressions { - case a: Attribute if replacementMap.contains(a) => - replacementMap(a).withMetadata(a.metadata) + val newAttributePlan = newBatchesPlan transformAllExpressions { case ct: CurrentTimestamp => CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, ct.dataType) @@ -392,6 +452,26 @@ class MicroBatchExecution( cd.dataType, cd.timeZoneId) } + val triggerLogicalPlan = sink match { + case _: Sink => newAttributePlan + case s: StreamWriteSupport => + val writer = s.createStreamWriter( + s"$runId", + newAttributePlan.schema, + outputMode, + new DataSourceOptions(extraOptions.asJava)) + if (writer.isInstanceOf[SupportsWriteInternalRow]) { + WriteToDataSourceV2( + new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan) + } else { + WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) + } + case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") + } + + sparkSessionToRunBatch.sparkContext.setLocalProperty( + MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString) + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionToRunBatch, @@ -409,16 +489,43 @@ class MicroBatchExecution( reportTimeTaken("addBatch") { SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { - sink.addBatch(currentBatchId, nextBatch) + sink match { + case s: Sink => s.addBatch(currentBatchId, nextBatch) + case _: StreamWriteSupport => + // This doesn't accumulate any data - it just forces execution of the microbatch writer. + nextBatch.collect() + } } } + withProgressLocked { + commitLog.add(currentBatchId) + committedOffsets ++= availableOffsets + awaitProgressLockCondition.signalAll() + } + watermarkTracker.updateWatermark(lastExecution.executedPlan) + logDebug(s"Completed batch ${currentBatchId}") + } + + /** Execute a function while locking the stream from making an progress */ + private[sql] def withProgressLocked[T](f: => T): T = { awaitProgressLock.lock() try { - // Wake up any threads that are waiting for the stream to progress. - awaitProgressLockCondition.signalAll() + f } finally { awaitProgressLock.unlock() } } + + private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { + Optional.ofNullable(scalaOption.orNull) + } } + +object MicroBatchExecution { + val BATCH_ID_KEY = "streaming.sql.batchId" +} + +object MemoryStreamDataSource extends DataSourceV2 + +object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index a1b63a6de382..787174481ff0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -32,14 +32,16 @@ import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PR case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMetadata] = None) { /** - * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of + * Unpacks an offset into [[StreamProgress]] by associating each offset with the ordered list of * sources. * * This method is typically used to associate a serialized offset with actual sources (which * cannot be serialized). 
*/ def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = { - assert(sources.size == offsets.size) + assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " + + s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " + + s"Cannot continue.") new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index e3f4abcf9f1d..2c8d7c7b0f3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * This class is used to log offsets to persistent files in HDFS. * Each file corresponds to a specific batch of offsets. The file - * format contain a version string in the first line, followed + * format contains a version string in the first line, followed * by a the JSON string representation of the offsets separated * by a newline character. If a source offset is missing, then * that line will contain a string value defined in the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 1c9043613cb6..16ad3ef9a3d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -53,7 +55,7 @@ trait ProgressReporter extends Logging { protected def triggerClock: Clock protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution - protected def newData: Map[BaseStreamingSource, DataFrame] + protected def newData: Map[BaseStreamingSource, LogicalPlan] protected def availableOffsets: StreamProgress protected def committedOffsets: StreamProgress protected def sources: Seq[BaseStreamingSource] @@ -141,7 +143,7 @@ trait ProgressReporter extends Logging { } logDebug(s"Execution stats: $executionStats") - val sourceProgress = sources.map { source => + val sourceProgress = sources.distinct.map { source => val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, @@ -207,62 +209,126 @@ trait ProgressReporter extends Logging { return ExecutionStats(Map.empty, stateOperators, watermarkTimestamp) } - // We want to associate execution plan leaves to sources that generate them, so that we match - // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. - // Consider the translation from the streaming logical plan to the final executed plan. 
- // - // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan - // - // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan - // - Each logical plan leaf will be associated with a single streaming source. - // - There can be multiple logical plan leaves associated with a streaming source. - // - There can be leaves not associated with any streaming source, because they were - // generated from a batch source (e.g. stream-batch joins) - // - // 2. Assuming that the executed plan has same number of leaves in the same order as that of - // the trigger logical plan, we associate executed plan leaves with corresponding - // streaming sources. - // - // 3. For each source, we sum the metrics of the associated execution plan leaves. - // - val logicalPlanLeafToSource = newData.flatMap { case (source, df) => - df.logicalPlan.collectLeaves().map { leaf => leaf -> source } + val numInputRows = extractSourceToNumInputRows() + + val eventTimeStats = lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => + val stats = e.eventTimeStats.value + Map( + "max" -> stats.max, + "min" -> stats.min, + "avg" -> stats.avg.toLong).mapValues(formatTimestamp) + }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp + + ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } + + /** Extract number of input sources for each streaming source in plan */ + private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = { + + import java.util.IdentityHashMap + import scala.collection.JavaConverters._ + + def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = { + tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source } - val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming - val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() - val numInputRows: Map[BaseStreamingSource, Long] = + + val onlyDataSourceV2Sources = { + // Check whether the streaming query's logical plan has only V2 data sources + val allStreamingLeaves = + logicalPlan.collect { case s: StreamingExecutionRelation => s } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } + } + + if (onlyDataSourceV2Sources) { + // DataSourceV2ScanExec is the execution plan leaf that is responsible for reading data + // from a V2 source and has a direct reference to the V2 source that generated it. Each + // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However, + // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as + // a DataSourceV2ScanExec instance may be referred to in the execution plan from two (or + // even multiple times) points and considering it twice will leads to double counting. We + // can't dedup them using their hashcode either because two different instances of + // DataSourceV2ScanExec can have the same hashcode but account for separate sets of + // records read, and deduping them to consider only one of them would be undercounting the + // records read. Therefore the right way to do this is to consider the unique instances of + // DataSourceV2ScanExec (using their identity hash codes) and get metrics from them. + // Hence we calculate in the following way. + // + // 1. Collect all the unique DataSourceV2ScanExec instances using IdentityHashMap. + // + // 2. 
Extract the source and the number of rows read from the DataSourceV2ScanExec instanes. + // + // 3. Multiple DataSourceV2ScanExec instance may refer to the same source (can happen with + // self-unions or self-joins). Add up the number of rows for each unique source. + val uniqueStreamingExecLeavesMap = + new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() + + lastExecution.executedPlan.collectLeaves().foreach { + case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => + uniqueStreamingExecLeavesMap.put(s, s) + case _ => + } + + val sourceToInputRowsTuples = + uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => + val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] + source -> numRows + }.toSeq + logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) + sumRows(sourceToInputRowsTuples) + } else { + + // Since V1 source do not generate execution plan leaves that directly link with source that + // generated it, we can only do a best-effort association between execution plan leaves to the + // sources. This is known to fail in a few cases, see SPARK-24050. + // + // We want to associate execution plan leaves to sources that generate them, so that we match + // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. + // Consider the translation from the streaming logical plan to the final executed plan. + // + // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan + // + // 1. We keep track of streaming sources associated with each leaf in trigger's logical plan + // - Each logical plan leaf will be associated with a single streaming source. + // - There can be multiple logical plan leaves associated with a streaming source. + // - There can be leaves not associated with any streaming source, because they were + // generated from a batch source (e.g. stream-batch joins) + // + // 2. Assuming that the executed plan has same number of leaves in the same order as that of + // the trigger logical plan, we associate executed plan leaves with corresponding + // streaming sources. + // + // 3. For each source, we sum the metrics of the associated execution plan leaves. 
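Stripped of the Spark plan types, the identity-based deduplication described in the DataSourceV2 branch above amounts to the helper sketched below. The name sumRowsPerSource and its generic parameters are invented for illustration; the real code works directly on DataSourceV2ScanExec leaves and their numOutputRows metric.

import java.util.IdentityHashMap

import scala.collection.JavaConverters._

// Count rows per source while deduplicating plan leaves by object identity: two distinct
// but "equal" leaves each contribute their own rows, whereas the same instance reachable
// from several points in the plan is counted only once.
def sumRowsPerSource[Src, Leaf <: AnyRef](
    leaves: Seq[Leaf],
    sourceOf: Leaf => Src,
    rowsOf: Leaf => Long): Map[Src, Long] = {
  val unique = new IdentityHashMap[Leaf, Leaf]()
  leaves.foreach(leaf => unique.put(leaf, leaf))
  unique.values.asScala.toSeq
    .map(leaf => sourceOf(leaf) -> rowsOf(leaf))
    .groupBy(_._1)
    .map { case (src, rows) => src -> rows.map(_._2).sum }
}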
+ // + val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => + logicalPlan.collectLeaves().map { leaf => leaf -> source } + } + val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming + val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } } - val sourceToNumInputRows = execLeafToSource.map { case (execLeaf, source) => + val sourceToInputRowsTuples = execLeafToSource.map { case (execLeaf, source) => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) source -> numRows } - sourceToNumInputRows.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + sumRows(sourceToInputRowsTuples) } else { if (!metricWarningLogged) { def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}" + logWarning( "Could not report metrics as number leaves in trigger logical plan did not match that" + - s" of the execution plan:\n" + - s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + - s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") + s" of the execution plan:\n" + + s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + + s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") metricWarningLogged = true } Map.empty } - - val eventTimeStats = lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => - val stats = e.eventTimeStats.value - Map( - "max" -> stats.max, - "min" -> stats.min, - "avg" -> stats.avg.toLong).mapValues(formatTimestamp) - }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp - - ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } } /** Records the duration of running `body` for the next query progress update. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala deleted file mode 100644 index d02cf882b61a..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ /dev/null @@ -1,264 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.ContinuousRateStreamReader -import org.apache.spark.sql.execution.streaming.sources.RateStreamV2Reader -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader -import org.apache.spark.sql.types._ -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateSourceProvider extends StreamSourceProvider with DataSourceRegister - with DataSourceV2 with ContinuousReadSupport { - - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { - if (schema.nonEmpty) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - (shortName(), RateSourceProvider.SCHEMA) - } - - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val params = CaseInsensitiveMap(parameters) - - val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + - "must be positive") - } - - val rampUpTimeSeconds = - params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + - "must not be negative") - } - - val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( - sqlContext.sparkContext.defaultParallelism) - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("numPartitions")}'. 
The option 'numPartitions' " + - "must be positive") - } - - new RateStreamSource( - sqlContext, - metadataPath, - rowsPerSecond, - rampUpTimeSeconds, - numPartitions, - params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing - ) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceV2Options): ContinuousReader = { - new ContinuousRateStreamReader(options) - } - - override def shortName(): String = "rate" -} - -object RateSourceProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 -} - -class RateStreamSource( - sqlContext: SQLContext, - metadataPath: String, - rowsPerSecond: Long, - rampUpTimeSeconds: Long, - numPartitions: Int, - useManualClock: Boolean) extends Source with Logging { - - import RateSourceProvider._ - import RateStreamSource._ - - val clock = if (useManualClock) new ManualClock else new SystemClock - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private val startTimeMs = { - val metadataLog = - new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ - @volatile private var lastTimeMs = startTimeMs - - override def schema: StructType = RateSourceProvider.SCHEMA - - override def getOffset: Option[Offset] = { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) - } - - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. 
Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } - - val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - - val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => - val relative = math.round((v - rangeStart) * relativeMsPerValue) - InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) - } - - override def stop(): Unit = {} - - override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" -} - -object RateStreamSource { - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala index 261d69bbd984..02fed50485b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -23,7 +23,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.sources.v2 case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, ValueRunTimeMsPair]) - extends v2.streaming.reader.Offset { + extends v2.reader.streaming.Offset { implicit val defaultFormats: DefaultFormats = DefaultFormats override val json = Serialization.write(partitionToValueAndRunTimeMs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3e76bf7b7ca8..3fc8c7887896 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -163,7 +163,7 @@ abstract class StreamExecution( var lastExecution: IncrementalExecution = _ /** Holds the most recent input data for each source. */ - protected var newData: Map[BaseStreamingSource, DataFrame] = _ + protected var newData: Map[BaseStreamingSource, LogicalPlan] = _ @volatile protected var streamDeathCause: StreamingQueryException = null @@ -356,25 +356,7 @@ abstract class StreamExecution( private def isInterruptedByStop(e: Throwable): Boolean = { if (state.get == TERMINATED) { - e match { - // InterruptedIOException - thrown when an I/O operation is interrupted - // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted - case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => - true - // The cause of the following exceptions may be one of the above exceptions: - // - // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as - // BiFunction.apply - // ExecutionException - thrown by codes running in a thread pool and these codes throw an - // exception - // UncheckedExecutionException - thrown by codes that cannot throw a checked - // ExecutionException, such as BiFunction.apply - case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) - if e2.getCause != null => - isInterruptedByStop(e2.getCause) - case _ => - false - } + StreamExecution.isInterruptionException(e) } else { false } @@ -418,11 +400,17 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + if (sources == null) { + // sources might not be initialized yet + false + } else { + val source = sources(sourceIndex) + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + } } while (notDone) { @@ -436,7 +424,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for $source") + logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") } /** A flag to indicate that a batch has completed with no new data available. 
*/ @@ -559,6 +547,26 @@ abstract class StreamExecution( object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" + + def isInterruptionException(e: Throwable): Boolean = e match { + // InterruptedIOException - thrown when an I/O operation is interrupted + // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted + case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => + true + // The cause of the following exceptions may be one of the above exceptions: + // + // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as + // BiFunction.apply + // ExecutionException - thrown by codes running in a thread pool and these codes throw an + // exception + // UncheckedExecutionException - thrown by codes that cannot throw a checked + // ExecutionException, such as BiFunction.apply + case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) + if e2.getCause != null => + isInterruptionException(e2.getCause) + case _ => + false + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala index 020c9cb4a730..3f2cdadfbaee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus} /** - * Wrap non-serializable StreamExecution to make the query serializable as it's easy to for it to + * Wrap non-serializable StreamExecution to make the query serializable as it's easy for it to * get captured with normal usage. It's safe to capture the query but not use it in executors. * However, if the user tries to call its methods, it will throw `IllegalStateException`. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index a9d50e3a112e..f02d3a2c3733 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -20,13 +20,12 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LeafNode -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -43,7 +42,7 @@ object StreamingRelation { * passing to [[StreamExecution]] to run a query. 
*/ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute]) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = sourceName @@ -54,6 +53,8 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance())) } /** @@ -61,9 +62,9 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. */ case class StreamingExecutionRelation( - source: Source, + source: BaseStreamingSource, output: Seq[Attribute])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = source.toString @@ -75,6 +76,8 @@ case class StreamingExecutionRelation( override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } // We have to pack in the V1 data source as a shim, for the case when a source implements @@ -92,14 +95,16 @@ case class StreamingRelationV2( sourceName: String, extraOptions: Map[String, String], output: Seq[Attribute], - v1DataSource: DataSource)(session: SparkSession) - extends LeafNode { + v1Relation: Option[StreamingRelation])(session: SparkSession) + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = sourceName override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } /** @@ -109,7 +114,7 @@ case class ContinuousExecutionRelation( source: ContinuousReadSupport, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = source.toString @@ -121,6 +126,8 @@ case class ContinuousExecutionRelation( override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index c351f658cb95..fa7c8ee906ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,7 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + ClusteredDistribution(rightKeys, 
stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 271bc4da99c0..19e3e55cb282 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.streaming.Trigger /** - * A [[Trigger]] that process only one batch of data in a streaming query then terminates + * A [[Trigger]] that processes only one batch of data in a streaming query then terminates * the query. */ @Experimental diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala new file mode 100644 index 000000000000..80865669558d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.SparkPlan + +class WatermarkTracker extends Logging { + private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() + private var watermarkMs: Long = 0 + private var updated = false + + def setWatermark(newWatermarkMs: Long): Unit = synchronized { + watermarkMs = newWatermarkMs + } + + def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { + val watermarkOperators = executedPlan.collect { + case e: EventTimeWatermarkExec => e + } + if (watermarkOperators.isEmpty) return + + + watermarkOperators.zipWithIndex.foreach { + case (e, index) if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs + val prevWatermarkMs = operatorToWatermarkMap.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + operatorToWatermarkMap.put(index, newWatermarkMs) + } + + // Populate 0 if we haven't seen any data yet for this watermark node. + case (_, index) => + if (!operatorToWatermarkMap.isDefinedAt(index)) { + operatorToWatermarkMap.put(index, 0) + } + } + + // Update the global watermark to the minimum of all watermark nodes. + // This is the safest option, because only the global watermark is fault-tolerant. 
Making + // it the minimum of all individual watermarks guarantees it will never advance past where + // any individual watermark operator would be if it were in a plan by itself. + val newWatermarkMs = operatorToWatermarkMap.minBy(_._2)._2 + if (newWatermarkMs > watermarkMs) { + logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") + watermarkMs = newWatermarkMs + updated = true + } else { + logDebug(s"Event time didn't move: $newWatermarkMs < $watermarkMs") + updated = false + } + } + + def currentWatermark: Long = synchronized { watermarkMs } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 71eaabe273fe..cfba1001c6de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -17,58 +17,30 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider} +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType -class ConsoleSink(options: Map[String, String]) extends Sink with Logging { - // Number of rows to display, by default 20 rows - private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20) - - // Truncate the displayed data if it is too long, by default it is true - private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true) - - // Track the batch id - private var lastBatchId = -1L - - override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { - val batchIdStr = if (batchId <= lastBatchId) { - s"Rerun batch: $batchId" - } else { - lastBatchId = batchId - s"Batch: $batchId" - } - - // scalastyle:off println - println("-------------------------------------------") - println(batchIdStr) - println("-------------------------------------------") - // scalastyle:off println - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) - .show(numRowsToShow, isTruncated) - } - - override def toString(): String = s"ConsoleSink[numRows=$numRowsToShow, truncate=$isTruncated]" -} - case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) extends BaseRelation { override def schema: StructType = data.schema } -class ConsoleSinkProvider extends StreamSinkProvider +class ConsoleSinkProvider extends DataSourceV2 + with StreamWriteSupport with DataSourceRegister with CreatableRelationProvider { - def createSink( - sqlContext: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String], - outputMode: OutputMode): Sink = { - new ConsoleSink(parameters) + + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + new ConsoleWriter(schema, options) } def 
createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala new file mode 100644 index 000000000000..0a3b9dcccb6c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.util.{NextIterator, ThreadUtils} + +class ContinuousDataSourceRDDPartition( + val index: Int, + val readerFactory: DataReaderFactory[UnsafeRow]) + extends Partition with Serializable { + + // This is semantically a lazy val - it's initialized once the first time a call to + // ContinuousDataSourceRDD.compute() needs to access it, so it can be shared across + // all compute() calls for a partition. This ensures that one compute() picks up where the + // previous one ended. + // We don't make it actually a lazy val because it needs input which isn't available here. + // This will only be initialized on the executors. + private[continuous] var queueReader: ContinuousQueuedDataReader = _ +} + +/** + * The bottom-most RDD of a continuous processing read task. Wraps a [[ContinuousQueuedDataReader]] + * to read from the remote source, and polls that queue for incoming rows. + * + * Note that continuous processing calls compute() multiple times, and the same + * [[ContinuousQueuedDataReader]] instance will/must be shared between each call for the same split. + */ +class ContinuousDataSourceRDD( + sc: SparkContext, + dataQueueSize: Int, + epochPollIntervalMs: Long, + @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + readerFactories.zipWithIndex.map { + case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory) + }.toArray + } + + /** + * Initialize the shared reader for this partition if needed, then read rows from it until + * it returns null to signal the end of the epoch. 
+ */ + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + // If attempt number isn't 0, this is a task retry, which we don't support. + if (context.attemptNumber() != 0) { + throw new ContinuousTaskRetryException() + } + + val readerForPartition = { + val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] + if (partition.queueReader == null) { + partition.queueReader = + new ContinuousQueuedDataReader( + partition.readerFactory, context, dataQueueSize, epochPollIntervalMs) + } + + partition.queueReader + } + + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = { + readerForPartition.next() match { + case null => + finished = true + null + case row => row + } + } + + override def close(): Unit = {} + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split.asInstanceOf[ContinuousDataSourceRDDPartition].readerFactory.preferredLocations() + } +} + +object ContinuousDataSourceRDD { + private[continuous] def getContinuousReader( + reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { + reader match { + case r: ContinuousDataReader[UnsafeRow] => r + case wrapped: RowToUnsafeDataReader => + wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] + case _ => + throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala deleted file mode 100644 index d79e4bd65f56..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming.continuous - -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} - -import scala.collection.JavaConverters._ - -import org.apache.spark._ -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, PartitionOffset} -import org.apache.spark.sql.streaming.ProcessingTime -import org.apache.spark.util.{SystemClock, ThreadUtils} - -class ContinuousDataSourceRDD( - sc: SparkContext, - sqlContext: SQLContext, - @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { - - private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize - private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs - - override protected def getPartitions: Array[Partition] = { - readTasks.asScala.zipWithIndex.map { - case (readTask, index) => new DataSourceRDDPartition(index, readTask) - }.toArray - } - - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() - - val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) - - // This queue contains two types of messages: - // * (null, null) representing an epoch boundary. - // * (row, off) containing a data row and its corresponding PartitionOffset. 
- val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize) - - val epochPollFailed = new AtomicBoolean(false) - val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( - s"epoch-poll--${runId}--${context.partitionId()}") - val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) - epochPollExecutor.scheduleWithFixedDelay( - epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - - // Important sequencing - we must get start offset before the data reader thread begins - val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset - - val dataReaderFailed = new AtomicBoolean(false) - val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed) - dataReaderThread.setDaemon(true) - dataReaderThread.start() - - context.addTaskCompletionListener(_ => { - reader.close() - dataReaderThread.interrupt() - epochPollExecutor.shutdown() - }) - - val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get) - new Iterator[UnsafeRow] { - private val POLL_TIMEOUT_MS = 1000 - - private var currentEntry: (UnsafeRow, PartitionOffset) = _ - private var currentOffset: PartitionOffset = startOffset - private var currentEpoch = - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def hasNext(): Boolean = { - while (currentEntry == null) { - if (context.isInterrupted() || context.isCompleted()) { - currentEntry = (null, null) - } - if (dataReaderFailed.get()) { - throw new SparkException("data read failed", dataReaderThread.failureReason) - } - if (epochPollFailed.get()) { - throw new SparkException("epoch poll failed", epochPollRunnable.failureReason) - } - currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) - } - - currentEntry match { - // epoch boundary marker - case (null, null) => - epochEndpoint.send(ReportPartitionOffset( - context.partitionId(), - currentEpoch, - currentOffset)) - currentEpoch += 1 - currentEntry = null - false - // real row - case (_, offset) => - currentOffset = offset - true - } - } - - override def next(): UnsafeRow = { - if (currentEntry == null) throw new NoSuchElementException("No current row was set") - val r = currentEntry._1 - currentEntry = null - r - } - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() - } -} - -case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset - -class EpochPollRunnable( - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread with Logging { - private[continuous] var failureReason: Throwable = _ - - private val epochEndpoint = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), SparkEnv.get) - private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def run(): Unit = { - try { - val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch) - for (i <- currentEpoch to newEpoch - 1) { - queue.put((null, null)) - logDebug(s"Sent marker to start epoch ${i + 1}") - } - currentEpoch = newEpoch - } catch { - case t: Throwable => - failureReason = t - failedFlag.set(true) - throw t - } - } -} - -class DataReaderThread( - reader: DataReader[UnsafeRow], - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread( - s"continuous-reader--${context.partitionId()}--" + - 
s"${context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)}") { - private[continuous] var failureReason: Throwable = _ - - override def run(): Unit = { - val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) - try { - while (!context.isInterrupted && !context.isCompleted()) { - if (!reader.next()) { - // Check again, since reader.next() might have blocked through an incoming interrupt. - if (!context.isInterrupted && !context.isCompleted()) { - throw new IllegalStateException( - "Continuous reader reported no elements! Reader should have blocked waiting.") - } else { - return - } - } - - queue.put((reader.get().copy(), baseReader.getOffset)) - } - } catch { - case _: InterruptedException if context.isInterrupted() => - // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. - - case t: Throwable => - failureReason = t - failedFlag.set(true) - // Don't rethrow the exception in this thread. It's not needed, and the default Spark - // exception handler will kill the executor. - } - } -} - -object ContinuousDataSourceRDD { - private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { - reader match { - case r: ContinuousDataReader[UnsafeRow] => r - case wrapped: RowToUnsafeDataReader => - wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 2843ab13bde2..f58146ac4239 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,22 +17,23 @@ package org.apache.spark.sql.execution.streaming.continuous +import java.util.UUID import java.util.concurrent.TimeUnit +import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} import org.apache.spark.SparkEnv -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} -import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2 +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} +import 
org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} @@ -42,7 +43,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: ContinuousWriteSupport, + sink: StreamWriteSupport, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -52,13 +53,13 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources - override lazy val logicalPlan: LogicalPlan = { - assert(queryExecutionThread eq Thread.currentThread, - "logicalPlan must be initialized in StreamExecutionThread " + - s"but the current thread was ${Thread.currentThread}") + // For use only in test harnesses. + private[sql] var currentEpochCoordinatorId: String = _ + + override val logicalPlan: LogicalPlan = { val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( @@ -67,7 +68,7 @@ class ContinuousExecution( ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) case StreamingRelationV2(_, sourceName, _, _, _) => - throw new AnalysisException( + throw new UnsupportedOperationException( s"Data source $sourceName does not support continuous processing.") } } @@ -78,15 +79,17 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - do { - try { - runContinuous(sparkSessionForStream) - } catch { - case _: InterruptedException if state.get().equals(RECONFIGURING) => - // swallow exception and run again - state.set(ACTIVE) + val stateUpdate = new UnaryOperator[State] { + override def apply(s: State) = s match { + // If we ended the query to reconfigure, reset the state to active. + case RECONFIGURING => ACTIVE + case _ => s } - } while (state.get() == ACTIVE) + } + + do { + runContinuous(sparkSessionForStream) + } while (state.updateAndGet(stateUpdate) == ACTIVE) } /** @@ -120,12 +123,16 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Forcibly align commit and offset logs by slicing off any spurious offset logs from - // a previous run. We can't allow commits to an epoch that a previous run reached but - // this run has not. - offsetLog.purgeAfter(latestEpochId) + // Get to an epoch ID that has definitely never been sent to a sink before. Since sink + // commit happens between offset log write and commit log write, this means an epoch ID + // which is not in the offset log. + val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { + throw new IllegalStateException( + s"Offset log had no latest element. 
This shouldn't be possible because nextOffsets is" + + s"an element.") + } + currentBatchId = latestOffsetEpoch + 1 - currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -153,7 +160,7 @@ class ContinuousExecution( dataSource.createContinuousReader( java.util.Optional.empty[StructType](), metadataPath, - new DataSourceV2Options(extraReaderOptions.asJava)) + new DataSourceOptions(extraReaderOptions.asJava)) } uniqueSources = continuousSources.distinct @@ -161,7 +168,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(_, _, output) => + case ContinuousExecutionRelation(source, options, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -173,8 +180,8 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) - reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull)) - DataSourceV2Relation(newOutput, reader) + reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) + StreamingDataSourceV2Relation(newOutput, source, options, reader) } // Rewire the plan to use the new attributes that were returned by the source. @@ -187,15 +194,15 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createContinuousWriter( + val writer = sink.createStreamWriter( s"$runId", triggerLogicalPlan.schema, outputMode, - new DataSourceV2Options(extraOptions.asJava)) - val withSink = WriteToDataSourceV2(writer.get(), triggerLogicalPlan) + new DataSourceOptions(extraOptions.asJava)) + val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) val reader = withSink.collect { - case DataSourceV2Relation(_, r: ContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { @@ -210,28 +217,28 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } - sparkSession.sparkContext.setLocalProperty( + sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) - sparkSession.sparkContext.setLocalProperty( - ContinuousExecution.RUN_ID_KEY, runId.toString) + // Add another random ID on top of the run ID, to distinguish epoch coordinators across + // reconfigurations. + val epochCoordinatorId = s"$runId--${UUID.randomUUID}" + currentEpochCoordinatorId = epochCoordinatorId + sparkSessionForQuery.sparkContext.setLocalProperty( + ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) // Use the parent Spark session for the endpoint since it's where this query ID is registered. 
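For reference, the epoch coordinator ID stored in the local property above is how executor-side code finds the same endpoint again. A minimal sketch of that lookup, assuming a TaskContext named `context` and partition-local `currentEpoch`/`currentOffset` (names as introduced elsewhere in this patch):

  // Executor side: resolve the coordinator the driver registered under this ID.
  val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
  val epochCoordinator = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get)
  // Partitions then report progress to it, for example:
  epochCoordinator.send(
    ReportPartitionOffset(context.partitionId(), currentEpoch, currentOffset))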
val epochEndpoint = EpochCoordinatorRef.create( - writer.get(), reader, this, currentBatchId, sparkSession, SparkEnv.get) + writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration()) { - state.set(RECONFIGURING) - stopSources() + if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { if (queryExecutionThread.isAlive) { - sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() - // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -258,11 +265,20 @@ class ContinuousExecution( SQLExecution.withNewExecutionId( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } + } catch { + case t: Throwable + if StreamExecution.isInterruptionException(t) && state.get() == RECONFIGURING => + logInfo(s"Query $id ignoring exception from reconfiguring: $t") + // interrupted by reconfiguration - swallow exception so we can restart the query } finally { + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() epochUpdateThread.join() + + stopSources() + sparkSession.sparkContext.cancelJobGroup(runId.toString) } } @@ -273,17 +289,22 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - if (partitionOffsets.contains(null)) { - // If any offset is null, that means the corresponding partition hasn't seen any data yet, so - // there's nothing meaningful to add to the offset log. - } val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - synchronized { - if (queryExecutionThread.isAlive) { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - } else { - return - } + val oldOffset = synchronized { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + offsetLog.get(epoch - 1) + } + + // If offset hasn't changed since last epoch, there's been no new data. 
+ if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { + noNewData = true + } + + awaitProgressLock.lock() + try { + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() } } @@ -297,8 +318,10 @@ class ContinuousExecution( synchronized { if (queryExecutionThread.isAlive) { commitLog.add(epoch) - val offset = offsetLog.get(epoch).get.offsets(0).get + val offset = + continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json) committedOffsets ++= Seq(continuousSources(0) -> offset) + continuousSources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset]) } else { return } @@ -346,5 +369,5 @@ class ContinuousExecution( object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" - val RUN_ID_KEY = "__run_id" + val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala new file mode 100644 index 000000000000..01a999f6505f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.io.Closeable +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.util.control.NonFatal + +import org.apache.spark.{Partition, SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset +import org.apache.spark.util.ThreadUtils + +/** + * A wrapper for a continuous processing data reader, including a reading queue and epoch markers. + * + * This will be instantiated once per partition - successive calls to compute() in the + * [[ContinuousDataSourceRDD]] will reuse the same reader. This is required to get continuity of + * offsets across epochs. Each compute() should call the next() method here until null is returned. 
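A minimal sketch of the calling pattern this contract implies (the real consumer is the NextIterator in ContinuousDataSourceRDD.compute(); `process` is a placeholder):

  // Drain the shared reader until it signals the end of the current epoch.
  var row = queueReader.next()
  while (row != null) {
    process(row)              // placeholder for downstream handling of the UnsafeRow
    row = queueReader.next()
  }
  // null marks the epoch boundary; the next compute() call resumes from the same reader.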
+ */ +class ContinuousQueuedDataReader( + factory: DataReaderFactory[UnsafeRow], + context: TaskContext, + dataQueueSize: Int, + epochPollIntervalMs: Long) extends Closeable { + private val reader = factory.createDataReader() + + // Important sequencing - we must get our starting point before the provider threads start running + private var currentOffset: PartitionOffset = + ContinuousDataSourceRDD.getContinuousReader(reader).getOffset + private var currentEpoch: Long = + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + /** + * The record types in the read buffer. + */ + sealed trait ContinuousRecord + case object EpochMarker extends ContinuousRecord + case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord + + private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize) + + private val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + + private val epochMarkerExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + s"epoch-poll--$coordinatorId--${context.partitionId()}") + private val epochMarkerGenerator = new EpochMarkerGenerator + epochMarkerExecutor.scheduleWithFixedDelay( + epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) + + private val dataReaderThread = new DataReaderThread + dataReaderThread.setDaemon(true) + dataReaderThread.start() + + context.addTaskCompletionListener(_ => { + this.close() + }) + + private def shouldStop() = { + context.isInterrupted() || context.isCompleted() + } + + /** + * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done. + * + * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch + * will call next() again to start getting rows. + */ + def next(): UnsafeRow = { + val POLL_TIMEOUT_MS = 1000 + var currentEntry: ContinuousRecord = null + + while (currentEntry == null) { + if (shouldStop()) { + // Force the epoch to end here. The writer will notice the context is interrupted + // or completed and not start a new one. This makes it possible to achieve clean + // shutdown of the streaming query. + // TODO: The obvious generalization of this logic to multiple stages won't work. It's + // invalid to send an epoch marker from the bottom of a task if all its child tasks + // haven't sent one. + currentEntry = EpochMarker + } else { + if (dataReaderThread.failureReason != null) { + throw new SparkException("Data read failed", dataReaderThread.failureReason) + } + if (epochMarkerGenerator.failureReason != null) { + throw new SparkException( + "Epoch marker generation failed", + epochMarkerGenerator.failureReason) + } + currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } + } + + currentEntry match { + case EpochMarker => + epochCoordEndpoint.send(ReportPartitionOffset( + context.partitionId(), currentEpoch, currentOffset)) + currentEpoch += 1 + null + case ContinuousRow(row, offset) => + currentOffset = offset + row + } + } + + override def close(): Unit = { + dataReaderThread.interrupt() + epochMarkerExecutor.shutdown() + } + + /** + * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when + * a new row arrives to the [[DataReader]]. 
+ */ + class DataReaderThread extends Thread( + s"continuous-reader--${context.partitionId()}--" + + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + override def run(): Unit = { + TaskContext.setTaskContext(context) + val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader) + try { + while (!shouldStop()) { + if (!reader.next()) { + // Check again, since reader.next() might have blocked through an incoming interrupt. + if (!shouldStop()) { + throw new IllegalStateException( + "Continuous reader reported no elements! Reader should have blocked waiting.") + } else { + return + } + } + + queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset)) + } + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. + logInfo(s"shutting down interrupted data reader thread $getName") + + case NonFatal(t) => + failureReason = t + logWarning("data reader thread failed", t) + // If we throw from this thread, we may kill the executor. Let the parent thread handle + // it. + + case t: Throwable => + failureReason = t + throw t + } finally { + reader.close() + } + } + } + + /** + * The epoch marker component of [[ContinuousQueuedDataReader]]. Populates the queue with + * EpochMarker when a new epoch marker arrives. + */ + class EpochMarkerGenerator extends Runnable with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + // Note that this is *not* the same as the currentEpoch in [[ContinuousDataQueuedReader]]! That + // field represents the epoch wrt the data being processed. The currentEpoch here is just a + // counter to ensure we send the appropriate number of markers if we fall behind the driver. + private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + override def run(): Unit = { + try { + val newEpoch = epochCoordEndpoint.askSync[Long](GetCurrentEpoch) + // It's possible to fall more than 1 epoch behind if a GetCurrentEpoch RPC ends up taking + // a while. We catch up by injecting enough epoch markers immediately to catch up. This will + // result in some epochs being empty for this partition, but that's fine. 
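A worked instance of the catch-up described above, with assumed numbers:

  // e.g. currentEpoch == 3 and the coordinator answers newEpoch == 6: the loop below runs
  // for i = 3, 4, 5 and enqueues three EpochMarkers, starting epochs 4, 5 and 6 here.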
+ for (i <- currentEpoch to newEpoch - 1) { + queue.put(EpochMarker) + logDebug(s"Sent marker to start epoch ${i + 1}") + } + currentEpoch = newEpoch + } catch { + case t: Throwable => + failureReason = t + throw t + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index c9aa78a5a2e2..2f0de2612c15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -23,32 +23,31 @@ import org.json4s.DefaultFormats import org.json4s.jackson.Serialization import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType -case class ContinuousRateStreamPartitionOffset( +case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class ContinuousRateStreamReader(options: DataSourceV2Options) +class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { assert(offsets.length == numPartitions) val tuples = offsets.map { - case ContinuousRateStreamPartitionOffset(i, currVal, nextRead) => + case RateStreamPartitionOffset(i, currVal, nextRead) => (i, ValueRunTimeMsPair(currVal, nextRead)) } RateStreamOffset(Map(tuples: _*)) @@ -58,17 +57,17 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateSourceProvider.SCHEMA + override def readSchema(): StructType = RateStreamProvider.SCHEMA private var offset: Offset = _ - override def setOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = 
offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) + override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { + this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset - override def createReadTasks(): java.util.List[ReadTask[Row]] = { + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => @@ -86,33 +85,60 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamReadTask( + RateStreamContinuousDataReaderFactory( start.value, start.runTimeMs, i, numPartitions, perPartitionRate) - .asInstanceOf[ReadTask[Row]] + .asInstanceOf[DataReaderFactory[Row]] }.asJava } override def commit(end: Offset): Unit = {} override def stop(): Unit = {} + private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } + } -case class RateStreamReadTask( +case class RateStreamContinuousDataReaderFactory( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ReadTask[Row] { + extends ContinuousDataReaderFactory[Row] { + + override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = { + val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] + require(rateStreamOffset.partition == partitionIndex, + s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") + new RateStreamContinuousDataReader( + rateStreamOffset.currentValue, + rateStreamOffset.currentTimeMs, + partitionIndex, + increment, + rowsPerSecond) + } + override def createDataReader(): DataReader[Row] = - new RateStreamDataReader(startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) + new RateStreamContinuousDataReader( + startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } -class RateStreamDataReader( +class RateStreamContinuousDataReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, @@ -151,5 +177,5 @@ class RateStreamDataReader( override def close(): Unit = {} override def getOffset(): PartitionOffset = - ContinuousRateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime) + RateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime) } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala similarity index 71% rename from core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala index b1217980faf1..e0a6f6dd50bb 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala @@ -15,17 +15,12 @@ * limitations under the License. */ -package org.apache.spark.util +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.SparkException /** - * Used for shipping per-thread stacktraces from the executors to driver. + * An exception thrown when a continuous processing task runs with a nonzero attempt ID. */ -private[spark] case class ThreadStackTrace( - threadId: Long, - threadName: String, - threadState: Thread.State, - stackTrace: String, - blockedByThreadId: Option[Long], - blockedByLock: String, - holdingLocks: Seq[String]) - +class ContinuousTaskRetryException + extends SparkException("Continuous execution does not support task retry", null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala new file mode 100644 index 000000000000..91f157658151 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.{Partition, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.util.Utils + +/** + * The RDD writing to a sink in continuous processing. + * + * Within each task, we repeatedly call prev.compute(). Each resulting iterator contains the data + * to be written for one epoch, which we commit and forward to the driver. + * + * We keep repeating prev.compute() and writing new epochs until the query is shut down. 
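The loop described above, condensed for orientation (the full version with error handling follows in compute(); `split` and `context` come from the RDD contract):

  while (!context.isInterrupted() && !context.isCompleted()) {
    val dataIterator = prev.compute(split, context)        // one epoch of input
    val dataWriter = writeTask.createDataWriter(
      context.partitionId(), context.attemptNumber(), currentEpoch)
    while (dataIterator.hasNext) { dataWriter.write(dataIterator.next()) }
    epochCoordinator.send(
      CommitPartitionEpoch(context.partitionId(), currentEpoch, dataWriter.commit()))
    currentEpoch += 1
  }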
+ */ +class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) + extends RDD[Unit](prev) { + + override val partitioner = prev.partitioner + + override def getPartitions: Array[Partition] = prev.partitions + + override def compute(split: Partition, context: TaskContext): Iterator[Unit] = { + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + SparkEnv.get) + var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + while (!context.isInterrupted() && !context.isCompleted()) { + var dataWriter: DataWriter[InternalRow] = null + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + val dataIterator = prev.compute(split, context) + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (dataIterator.hasNext) { + dataWriter.write(dataIterator.next()) + } + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch $currentEpoch is committing.") + val msg = dataWriter.commit() + epochCoordinator.send( + CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + ) + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch $currentEpoch committed.") + currentEpoch += 1 + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. + } + })(catchBlock = { + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so compute() will stop executing at this point. + logError(s"Writer for partition ${context.partitionId()} is aborting.") + if (dataWriter != null) dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } + + Iterator() + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 98017c3ac6a3..cc6808065c0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -17,18 +17,15 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.concurrent.atomic.AtomicLong - import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -39,6 +36,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends 
EpochCoordinatorMessage +/** + * The RpcEndpoint stop() will wait to clear out the message queue before terminating the + * object. This can lead to a race condition where the query restarts at epoch n, a new + * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. + * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous + * message to stop any writes to the ContinuousExecution object. + */ +private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage + // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -70,27 +76,28 @@ private[sql] case class ReportPartitionOffset( /** Helper object used to create reference to [[EpochCoordinator]]. */ private[sql] object EpochCoordinatorRef extends Logging { - private def endpointName(runId: String) = s"EpochCoordinator-$runId" + private def endpointName(id: String) = s"EpochCoordinator-$id" /** * Create a reference to a new [[EpochCoordinator]]. */ def create( - writer: ContinuousWriter, + writer: StreamWriter, reader: ContinuousReader, query: ContinuousExecution, + epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( writer, reader, query, startEpoch, session, env.rpcEnv) - val ref = env.rpcEnv.setupEndpoint(endpointName(query.runId.toString()), coordinator) + val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref } - def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized { - val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv) + def get(id: String, env: SparkEnv): RpcEndpointRef = synchronized { + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(id), env.conf, env.rpcEnv) logDebug("Retrieved existing EpochCoordinator endpoint") rpcEndpointRef } @@ -108,7 +115,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writer: ContinuousWriter, + writer: StreamWriter, reader: ContinuousReader, query: ContinuousExecution, startEpoch: Long, @@ -116,6 +123,8 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + private var queryWritesStopped: Boolean = false + private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -147,12 +156,16 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) + partitionOffsets.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { + // If we just drop these messages, we won't do any writes to the query. The lame duck tasks + // won't shed errors or anything. 
+ case _ if queryWritesStopped => () + case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -188,5 +201,9 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) + + case StopContinuousExecutionWrites => + queryWritesStopped = true + context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala new file mode 100644 index 000000000000..943c731a7052 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter + +/** + * The logical plan for writing data in a continuous stream. + */ +case class WriteToContinuousDataSource( + writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala new file mode 100644 index 000000000000..e0af3a2f1b85 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory} +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.util.Utils + +/** + * The physical plan for writing data into a continuous processing [[StreamWriter]]. + */ +case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) + extends SparkPlan with Logging { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writerFactory = writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + } + + val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) + + logInfo(s"Start processing data source writer: $writer. " + + s"The input RDD has ${rdd.partitions.length} partitions.") + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) + .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) + + try { + // Force the RDD to run so continuous processing starts; no data is actually being collected + // to the driver, as ContinuousWriteRDD outputs nothing. + rdd.collect() + } catch { + case _: InterruptedException => + // Interruption is how continuous queries are ended, so accept and ignore the exception. + case cause: Throwable => + cause match { + // Do not wrap interruption exceptions that will be handled by streaming specially. + case _ if StreamExecution.isInterruptionException(cause) => throw cause + // Only wrap non fatal exceptions. 
+ case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } + } + + sparkContext.emptyRDD + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 3041d4d703cb..6720cdd24b1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,26 +17,28 @@ package org.apache.spark.sql.execution.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils - object MemoryStream { protected val currentBlockId = new AtomicInteger(0) protected val memoryStreamId = new AtomicInteger(0) @@ -45,15 +47,43 @@ object MemoryStream { new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) } +/** + * A base class for memory stream implementations. Supports adding data and resetting. + */ +abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource { + protected val encoder = encoderFor[A] + protected val attributes = encoder.schema.toAttributes + + def toDS(): Dataset[A] = { + Dataset[A](sqlContext.sparkSession, logicalPlan) + } + + def toDF(): DataFrame = { + Dataset.ofRows(sqlContext.sparkSession, logicalPlan) + } + + def addData(data: A*): Offset = { + addData(data.toTraversable) + } + + def readSchema(): StructType = encoder.schema + + protected def logicalPlan: LogicalPlan + + def addData(data: TraversableOnce[A]): Offset +} + /** * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] * is intended for use in unit tests as it can only replay data when the object is still * available. 
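Typical test usage of this source, sketched for reference (assumes an active SparkSession named `spark`):

  import org.apache.spark.sql.execution.streaming.MemoryStream
  import spark.implicits._                  // Int encoder
  implicit val sqlCtx = spark.sqlContext    // MemoryStream needs an implicit SQLContext
  val input = MemoryStream[Int]
  val query = input.toDS().writeStream.format("memory").queryName("numbers").start()
  input.addData(1, 2, 3)
  query.processAllAvailable()
  spark.table("numbers").show()             // 1, 2, 3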
*/ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends Source with Logging { - protected val encoder = encoderFor[A] - protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession) + extends MemoryStreamBase[A](sqlContext) + with MicroBatchReader with SupportsScanUnsafeRow with Logging { + + protected val logicalPlan: LogicalPlan = + StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) protected val output = logicalPlan.output /** @@ -61,11 +91,17 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) * Stored in a ListBuffer to facilitate removing committed batches. */ @GuardedBy("this") - protected val batches = new ListBuffer[Dataset[A]] + protected val batches = new ListBuffer[Array[UnsafeRow]] @GuardedBy("this") protected var currentOffset: LongOffset = new LongOffset(-1) + @GuardedBy("this") + protected var startOffset = new LongOffset(-1) + + @GuardedBy("this") + private var endOffset = new LongOffset(-1) + /** * Last offset that was discarded, or -1 if no commits have occurred. Note that the value * -1 is used in calculations below and isn't just an arbitrary constant. @@ -73,81 +109,68 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) @GuardedBy("this") protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) - def schema: StructType = encoder.schema - - def toDS(): Dataset[A] = { - Dataset(sqlContext.sparkSession, logicalPlan) - } - - def toDF(): DataFrame = { - Dataset.ofRows(sqlContext.sparkSession, logicalPlan) - } - - def addData(data: A*): Offset = { - addData(data.toTraversable) - } - def addData(data: TraversableOnce[A]): Offset = { - val encoded = data.toVector.map(d => encoder.toRow(d).copy()) - val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true) - val ds = Dataset[A](sqlContext.sparkSession, plan) - logDebug(s"Adding ds: $ds") + val objects = data.toSeq + val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray + logDebug(s"Adding: $objects") this.synchronized { currentOffset = currentOffset + 1 - batches += ds + batches += rows currentOffset } } override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def getOffset: Option[Offset] = synchronized { - if (currentOffset.offset == -1) { - None - } else { - Some(currentOffset) + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] + endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] } } - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) - val startOrdinal = - start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 - val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 + override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) - // Internal buffer only holds the batches after lastCommittedOffset. 
- val newBlocks = synchronized { - val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 - val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 - batches.slice(sliceStart, sliceEnd) - } + override def getStartOffset: OffsetV2 = synchronized { + if (startOffset.offset == -1) null else startOffset + } - logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) + override def getEndOffset: OffsetV2 = synchronized { + if (endOffset.offset == -1) null else endOffset + } - newBlocks - .map(_.toDF()) - .reduceOption(_ union _) - .getOrElse { - sys.error("No data selected!") + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + synchronized { + // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) + val startOrdinal = startOffset.offset.toInt + 1 + val endOrdinal = endOffset.offset.toInt + 1 + + // Internal buffer only holds the batches after lastCommittedOffset. + val newBlocks = synchronized { + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") + batches.slice(sliceStart, sliceEnd) } + + logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) + + newBlocks.map { block => + new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]] + }.asJava + } } private def generateDebugString( - blocks: TraversableOnce[Dataset[A]], + rows: Seq[UnsafeRow], startOrdinal: Int, endOrdinal: Int): String = { - val originalUnsupportedCheck = - sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck") - try { - sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false") - s"MemoryBatch [$startOrdinal, $endOrdinal]: " + - s"${blocks.flatMap(_.collect()).mkString(", ")}" - } finally { - sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck) - } + val fromRow = encoder.resolveAndBind().fromRow _ + s"MemoryBatch [$startOrdinal, $endOrdinal]: " + + s"${rows.map(row => fromRow(row)).mkString(", ")}" } - override def commit(end: Offset): Unit = synchronized { + override def commit(end: OffsetV2): Unit = synchronized { def check(newOffset: LongOffset): Unit = { val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt @@ -170,16 +193,47 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def reset(): Unit = synchronized { batches.clear() + startOffset = LongOffset(-1) + endOffset = LongOffset(-1) currentOffset = new LongOffset(-1) lastOffsetCommitted = new LongOffset(-1) } } + +class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) + extends DataReaderFactory[UnsafeRow] { + override def createDataReader(): DataReader[UnsafeRow] = { + new DataReader[UnsafeRow] { + private var currentIndex = -1 + + override def next(): Boolean = { + // Return true as long as the new index is in the array. + currentIndex += 1 + currentIndex < records.length + } + + override def get(): UnsafeRow = records(currentIndex) + + override def close(): Unit = {} + } + } +} + +/** A common trait for MemorySinks with methods used for testing */ +trait MemorySinkBase extends BaseStreamingSink { + def allData: Seq[Row] + def latestBatchData: Seq[Row] + def dataSinceBatch(sinceBatchId: Long): Seq[Row] + def latestBatchId: Option[Long] +} + /** * A sink that stores the results in memory. 
This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink + with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -189,7 +243,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { - batches.map(_.data).flatten + batches.flatMap(_.data) } def latestBatchId: Option[Long] = synchronized { @@ -198,6 +252,10 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { @@ -247,7 +305,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) - private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala new file mode 100644 index 000000000000..d276403190b3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.types.StructType + +/** Common methods used to create writes for the the console sink */ +class ConsoleWriter(schema: StructType, options: DataSourceOptions) + extends StreamWriter with Logging { + + // Number of rows to display, by default 20 rows + protected val numRowsToShow = options.getInt("numRows", 20) + + // Truncate the displayed data if it is too long, by default it is true + protected val isTruncated = options.getBoolean("truncate", true) + + assert(SparkSession.getActiveSession.isDefined) + protected val spark = SparkSession.getActiveSession.get + + def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 + // behavior. + printRows(messages, schema, s"Batch: $epochId") + } + + def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + protected def printRows( + commitMessages: Array[WriterCommitMessage], + schema: StructType, + printMessage: String): Unit = { + val rows = commitMessages.collect { + case PackedRowCommitMessage(rs) => rs + }.flatten + + // scalastyle:off println + println("-------------------------------------------") + println(printMessage) + println("-------------------------------------------") + // scalastyle:off println + spark + .createDataFrame(rows.toList.asJava, schema) + .show(numRowsToShow, isTruncated) + } + + override def toString(): String = { + s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala new file mode 100644 index 000000000000..a8fca3c19a2d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
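The two knobs read by ConsoleWriter, numRows and truncate, are ordinary data source options, so they can be exercised end to end from the public API. A hedged usage sketch pairing the console sink with the rate source from this same patch (spark-shell style; the local master and the 10-second run are illustrative assumptions):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("console-sink-options").getOrCreate()
val df = spark.readStream.format("rate").option("rowsPerSecond", "5").load()

val query = df.writeStream
  .format("console")
  .option("numRows", "50")       // show up to 50 rows per batch instead of the default 20
  .option("truncate", "false")   // keep long cell values intact
  .start()

query.awaitTermination(10000)    // let a few batches print, then shut down
query.stop()
spark.stop()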
+ */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.{util => ju} +import java.util.Optional +import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +import org.apache.spark.SparkEnv +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.{Encoder, Row, SQLContext} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.RpcUtils + +/** + * The overall strategy here is: + * * ContinuousMemoryStream maintains a list of records for each partition. addData() will + * distribute records evenly-ish across partitions. + * * RecordEndpoint is set up as an endpoint for executor-side + * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified + * offset within the list, or null if that offset doesn't yet have a record. + */ +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) + extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + private implicit val formats = Serialization.formats(NoTypeHints) + private val NUM_PARTITIONS = 2 + + protected val logicalPlan = + StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) + + // ContinuousReader implementation + + @GuardedBy("this") + private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) + + @GuardedBy("this") + private var startOffset: ContinuousMemoryStreamOffset = _ + + private val recordEndpoint = new RecordEndpoint() + @volatile private var endpointRef: RpcEndpointRef = _ + + def addData(data: TraversableOnce[A]): Offset = synchronized { + // Distribute data evenly among partition lists. + data.toSeq.zipWithIndex.map { + case (item, index) => records(index % NUM_PARTITIONS) += item + } + + // The new target offset is the offset where all records in all partitions have been processed. + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + } + + override def setStartOffset(start: Optional[Offset]): Unit = synchronized { + // Inferred initial offset is position 0 in each partition. 
+ startOffset = start.orElse { + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + }.asInstanceOf[ContinuousMemoryStreamOffset] + } + + override def getStartOffset: Offset = synchronized { + startOffset + } + + override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset(Serialization.read[Map[Int, Int]](json)) + } + + override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset( + offsets.map { + case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num) + }.toMap + ) + } + + override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + synchronized { + val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" + endpointRef = + recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) + + startOffset.partitionNums.map { + case (part, index) => + new ContinuousMemoryStreamDataReaderFactory( + endpointName, part, index): DataReaderFactory[Row] + }.toList.asJava + } + } + + override def stop(): Unit = { + if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) + } + + override def commit(end: Offset): Unit = {} + + // ContinuousReadSupport implementation + // This is necessary because of how StreamTest finds the source for AddDataMemory steps. + def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = { + this + } + + /** + * Endpoint for executors to poll for records. + */ + private class RecordEndpoint extends ThreadSafeRpcEndpoint { + override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) => + ContinuousMemoryStream.this.synchronized { + val buf = records(part) + val record = if (buf.size <= index) None else Some(buf(index)) + + context.reply(record.map(Row(_))) + } + } + } +} + +object ContinuousMemoryStream { + case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset) + protected val memoryStreamId = new AtomicInteger(0) + + def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) +} + +/** + * Data reader factory for continuous memory stream. + */ +class ContinuousMemoryStreamDataReaderFactory( + driverEndpointName: String, + partition: Int, + startOffset: Int) extends DataReaderFactory[Row] { + override def createDataReader: ContinuousMemoryStreamDataReader = + new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset) +} + +/** + * Data reader for continuous memory stream. + * + * Polls the driver endpoint for new records. 
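The class comment above describes addData() spreading records "evenly-ish" across partitions; concretely, item i lands in partition i % NUM_PARTITIONS and the returned offset simply records how many items each partition now holds. A small plain-Scala illustration (hypothetical values):

import scala.collection.mutable.ListBuffer

val numPartitions = 2
val records = Seq.fill(numPartitions)(new ListBuffer[String])

Seq("a", "b", "c", "d", "e").zipWithIndex.foreach { case (item, index) =>
  records(index % numPartitions) += item            // round-robin assignment
}
// records(0) contains a, c, e and records(1) contains b, d

// The target offset maps partition -> records available, mirroring
// ContinuousMemoryStreamOffset(partitionNums).
val partitionNums = (0 until numPartitions).map(i => (i, records(i).size)).toMap
// partitionNums == Map(0 -> 3, 1 -> 2)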
+ */ +class ContinuousMemoryStreamDataReader( + driverEndpointName: String, + partition: Int, + startOffset: Int) extends ContinuousDataReader[Row] { + private val endpoint = RpcUtils.makeDriverRef( + driverEndpointName, + SparkEnv.get.conf, + SparkEnv.get.rpcEnv) + + private var currentOffset = startOffset + private var current: Option[Row] = None + + override def next(): Boolean = { + current = getRecord + while (current.isEmpty) { + Thread.sleep(10) + current = getRecord + } + currentOffset += 1 + true + } + + override def get(): Row = current.get + + override def close(): Unit = {} + + override def getOffset: ContinuousMemoryStreamPartitionOffset = + ContinuousMemoryStreamPartitionOffset(partition, currentOffset) + + private def getRecord: Option[Row] = + endpoint.askSync[Option[Row]]( + GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) +} + +case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) + extends Offset { + private implicit val formats = Serialization.formats(NoTypeHints) + override def json(): String = Serialization.write(partitionNums) +} + +case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int) + extends PartitionOffset diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala new file mode 100644 index 000000000000..df5d69d57e36 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +/** + * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified + * [[ForeachWriter]]. + * + * @param writer The [[ForeachWriter]] to process all data. + * @tparam T The expected type of the sink. 
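ForeachWriterProvider only adapts the existing ForeachWriter contract to the V2 write path, so the user-facing API is unchanged. A hedged reminder of that contract with a Row-typed writer (the commented line shows where it would be attached to a streaming query):

import org.apache.spark.sql.{ForeachWriter, Row}

val writer = new ForeachWriter[Row] {
  // Returning false skips this partition for the epoch; ForeachDataWriter checks
  // the flag before calling process().
  override def open(partitionId: Long, epochId: Long): Boolean = true
  override def process(value: Row): Unit = println(value)
  override def close(errorOrNull: Throwable): Unit = ()
}

// df.writeStream.foreach(writer).outputMode("append").start()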
+ */ +case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + new StreamWriter with SupportsWriteInternalRow { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + val encoder = encoderFor[T].resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + ForeachWriterFactory(writer, encoder) + } + + override def toString: String = "ForeachSink" + } + } +} + +case class ForeachWriterFactory[T: Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]) + extends DataWriterFactory[InternalRow] { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): ForeachDataWriter[T] = { + new ForeachDataWriter(writer, encoder, partitionId, epochId) + } +} + +/** + * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * @param writer The [[ForeachWriter]] to process all data. + * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param partitionId + * @param epochId + * @tparam T The type expected by the writer. + */ +class ForeachDataWriter[T : Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T], + partitionId: Int, + epochId: Long) + extends DataWriter[InternalRow] { + + // If open returns false, we should skip writing rows. + private val opened = writer.open(partitionId, epochId) + + override def write(record: InternalRow): Unit = { + if (!opened) return + + try { + writer.process(encoder.fromRow(record)) + } catch { + case t: Throwable => + writer.close(t) + throw t + } + } + + override def commit(): WriterCommitMessage = { + writer.close(null) + ForeachWriterCommitMessage + } + + override def abort(): Unit = {} +} + +/** + * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. + */ +case object ForeachWriterCommitMessage extends WriterCommitMessage diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala new file mode 100644 index 000000000000..56f7ff25cbed --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter + +/** + * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements + * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped + * streaming writer. + */ +class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWriter { + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writer.commit(batchId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + + override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory() +} + +class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter) + extends DataSourceWriter with SupportsWriteInternalRow { + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writer.commit(batchId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = + writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => throw new IllegalStateException( + "InternalRowMicroBatchWriter should only be created with base writer support") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala new file mode 100644 index 000000000000..e07355aa37db --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} + +/** + * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery + * to a [[DataSourceWriter]] on the driver. + * + * Note that, because it sends all rows to the driver, this factory will generally be unsuitable + * for production-quality sinks. It's intended for use in tests. 
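Because each task packs its rows into the commit message, the driver-side DataSourceWriter recovers them by pattern matching on PackedRowCommitMessage, which is exactly what ConsoleWriter.printRows does above. A minimal sketch of that unpacking step (the helper name is hypothetical):

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.sources.PackedRowCommitMessage
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage

// Gather the rows carried by the commit messages; any other message type is ignored.
def unpackRows(messages: Array[WriterCommitMessage]): Array[Row] =
  messages.collect { case PackedRowCommitMessage(rows) => rows }.flatten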
+ */ +case object PackedRowWriterFactory extends DataWriterFactory[Row] { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { + new PackedRowDataWriter() + } +} + +/** + * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most + * recent interval. + */ +case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage + +/** + * A simple [[DataWriter]] that just sends all the rows it's received as a commit message. + */ +class PackedRowDataWriter() extends DataWriter[Row] with Logging { + private val data = mutable.Buffer[Row]() + + override def write(row: Row): Unit = data.append(row) + + override def commit(): PackedRowCommitMessage = { + val msg = PackedRowCommitMessage(data.toArray) + data.clear() + msg + } + + override def abort(): Unit = data.clear() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala new file mode 100644 index 000000000000..f54291bea667 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ManualClock, SystemClock} + +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReader with Logging { + import RateStreamProvider._ + + private[sources] val clock = { + // The option to use a manual clock is provided only for unit testing purposes. 
+ if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock + } + + private val rowsPerSecond = + options.get(ROWS_PER_SECOND).orElse("1").toLong + + private val rampUpTimeSeconds = + Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) + .map(JavaUtils.timeStringAsSec(_)) + .getOrElse(0L) + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private[sources] val creationTimeMs = { + val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + require(session.isDefined) + + val metadataLog = + new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + @volatile private var lastTimeMs: Long = creationTimeMs + + private var start: LongOffset = _ + private var end: LongOffset = _ + + override def readSchema(): StructType = SCHEMA + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] + this.end = end.orElse { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + }.asInstanceOf[LongOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. 
Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return List.empty.asJava + } + + val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + + (0 until numPartitions).map { p => + new RateStreamMicroBatchDataReaderFactory( + p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + : DataReaderFactory[Row] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +class RateStreamMicroBatchDataReaderFactory( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReaderFactory[Row] { + + override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( + partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) +} + +class RateStreamMicroBatchDataReader( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReader[Row] { + private var count = 0 + + override def next(): Boolean = { + rangeStart + partitionId + numPartitions * count < rangeEnd + } + + override def get(): Row = { + val currValue = rangeStart + partitionId + numPartitions * count + count += 1 + val relative = math.round((currValue - rangeStart) * relativeMsPerValue) + Row( + DateTimeUtils.toJavaTimestamp( + DateTimeUtils.fromMillis(relative + localStartTimeMs)), + currValue + ) + } + + override def close(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala new file mode 100644 index 000000000000..6bdd492f0cb3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
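For the RateStreamMicroBatchDataReader just added, values are interleaved across partitions: partition p emits rangeStart + p and then keeps stepping by numPartitions until it reaches rangeEnd, with timestamps derived from relativeMsPerValue. A small sketch of just the value striding (plain Scala, illustrative only):

def valuesForPartition(partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long): Seq[Long] =
  Iterator.from(0)
    .map(count => rangeStart + partitionId + numPartitions.toLong * count)
    .takeWhile(_ < rangeEnd)
    .toSeq

// With rangeStart = 0, rangeEnd = 10 and 2 partitions:
// valuesForPartition(0, 2, 0, 10) == Seq(0, 2, 4, 6, 8)
// valuesForPartition(1, 2, 0, 10) == Seq(1, 3, 5, 7, 9)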
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.types._ + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateStreamProvider extends DataSourceV2 + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + import RateStreamProvider._ + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + if (options.get(ROWS_PER_SECOND).isPresent) { + val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") + } + } + + if (options.get(RAMP_UP_TIME).isPresent) { + val rampUpTimeSeconds = + JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") + } + } + + if (options.get(NUM_PARTITIONS).isPresent) { + val numPartitions = options.get(NUM_PARTITIONS).get().toInt + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. 
The option 'numPartitions' must be positive") + } + } + + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + + override def shortName(): String = "rate" +} + +object RateStreamProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 + + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + val RAMP_UP_TIME = "rampUpTime" + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala deleted file mode 100644 index 97bada08bcd2..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
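The worked example in the valueAtSecond comment above can be checked directly: with rampUpTimeSeconds = 4 and rowsPerSecond = 10 the per-second speed delta is 2, and the cumulative end values for seconds 0 through 6 are 0, 2, 6, 12, 20, 30, 40. A spark-shell style sanity check (assumes the classes from this patch are on the classpath):

import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider.valueAtSecond

val expected = Seq(0L, 2L, 6L, 12L, 20L, 30L, 40L)
val actual = (0L to 6L).map(s => valueAtSecond(s, rowsPerSecond = 10L, rampUpTimeSeconds = 4L))
assert(actual == expected)   // the odd/even split avoids overflow but yields the same series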
- */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.json4s.DefaultFormats -import org.json4s.jackson.Serialization - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.SystemClock - -class RateStreamV2Reader(options: DataSourceV2Options) - extends MicroBatchReader { - implicit val defaultFormats: DefaultFormats = DefaultFormats - - val clock = new SystemClock - - private val numPartitions = - options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - private val rowsPerSecond = - options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong - - // The interval (in milliseconds) between rows in each partition. - // e.g. if there are 4 global rows per second, and 2 partitions, each partition - // should output rows every (1000 * 2 / 4) = 500 ms. - private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond - - override def readSchema(): StructType = { - StructType( - StructField("timestamp", TimestampType, false) :: - StructField("value", LongType, false) :: Nil) - } - - val creationTimeMs = clock.getTimeMillis() - - private var start: RateStreamOffset = _ - private var end: RateStreamOffset = _ - - override def setOffsetRange( - start: Optional[Offset], - end: Optional[Offset]): Unit = { - this.start = start.orElse( - RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) - .asInstanceOf[RateStreamOffset] - - this.end = end.orElse { - val currentTime = clock.getTimeMillis() - RateStreamOffset( - this.start.partitionToValueAndRunTimeMs.map { - case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - // Calculate the number of rows we should advance in this partition (based on the - // current time), and output a corresponding offset. 
- val readInterval = currentTime - currentReadTime - val numNewRows = readInterval / msPerPartitionBetweenRows - if (numNewRows <= 0) { - startOffset - } else { - (part, ValueRunTimeMsPair( - currentVal + (numNewRows * numPartitions), - currentReadTime + (numNewRows * msPerPartitionBetweenRows))) - } - } - ) - }.asInstanceOf[RateStreamOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) - } - - override def createReadTasks(): java.util.List[ReadTask[Row]] = { - val startMap = start.partitionToValueAndRunTimeMs - val endMap = end.partitionToValueAndRunTimeMs - endMap.keys.toSeq.map { part => - val ValueRunTimeMsPair(endVal, _) = endMap(part) - val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) - - val packedRows = mutable.ListBuffer[(Long, Long)]() - var outVal = startVal + numPartitions - var outTimeMs = startTimeMs + msPerPartitionBetweenRows - while (outVal <= endVal) { - packedRows.append((outTimeMs, outVal)) - outVal += numPartitions - outTimeMs += msPerPartitionBetweenRows - } - - RateStreamBatchTask(packedRows).asInstanceOf[ReadTask[Row]] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - override def stop(): Unit = {} -} - -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends ReadTask[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) -} - -class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - var currentIndex = -1 - - override def next(): Boolean = { - // Return true as long as the new index is in the seq. - currentIndex += 1 - currentIndex < vals.size - } - - override def get(): Row = { - Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), - vals(currentIndex)._2) - } - - override def close(): Unit = {} -} - -object RateStreamSourceV2 { - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - - private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. 
- (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index da7c31cf6242..468313bfe8c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -27,12 +27,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} -import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -40,24 +40,13 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. 
*/ -class MemorySinkV2 extends DataSourceV2 - with MicroBatchWriteSupport with ContinuousWriteSupport with Logging { - - override def createMicroBatchWriter( +class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { + override def createStreamWriter( queryId: String, - batchId: Long, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): java.util.Optional[DataSourceV2Writer] = { - java.util.Optional.of(new MemoryWriter(this, batchId, mode)) - } - - override def createContinuousWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): java.util.Optional[ContinuousWriter] = { - java.util.Optional.of(new ContinuousMemoryWriter(this, mode)) + options: DataSourceOptions): StreamWriter = { + new MemoryStreamWriter(this, mode) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -79,6 +68,10 @@ class MemorySinkV2 extends DataSourceV2 batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { @@ -108,7 +101,7 @@ class MemorySinkV2 extends DataSourceV2 case _ => throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") + s"Output mode $outputMode is not supported by MemorySinkV2") } } else { logDebug(s"Skipping already committed batch: $batchId") @@ -119,13 +112,13 @@ class MemorySinkV2 extends DataSourceV2 batches.clear() } - override def toString(): String = "MemorySink" + override def toString(): String = "MemorySinkV2" } case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) - extends DataSourceV2Writer with Logging { + extends DataSourceWriter with Logging { override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) @@ -141,8 +134,8 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class ContinuousMemoryWriter(val sink: MemorySinkV2, outputMode: OutputMode) - extends ContinuousWriter { +class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) + extends StreamWriter { override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) @@ -153,13 +146,16 @@ class ContinuousMemoryWriter(val sink: MemorySinkV2, outputMode: OutputMode) sink.write(epochId, outputMode, newRows) } - override def abort(messages: Array[WriterCommitMessage]): Unit = { + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // Don't accept any of the new input. } } case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { - def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { new MemoryDataWriter(partitionId, outputMode) } } @@ -184,10 +180,10 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) /** - * Used to query the data that has been written into a [[MemorySink]]. + * Used to query the data that has been written into a [[MemorySinkV2]]. 
*/ case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { - private val sizePerRow = output.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(output) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala similarity index 51% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 0b22cbc46e6b..90f4a5ba4234 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -15,27 +15,29 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat -import java.util.{Calendar, Locale} +import java.util.{Calendar, List => JList, Locale, Optional} import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} -import org.apache.spark.unsafe.types.UTF8String - -object TextSocketSource { +object TextSocketMicroBatchReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: StructField("timestamp", TimestampType) :: Nil) @@ -43,12 +45,17 @@ object TextSocketSource { } /** - * A source that reads text lines through a TCP socket, designed only for tutorials and debugging. - * This source will *not* work in production applications due to multiple reasons, including no - * support for fault recovery and keeping all of the text read in memory forever. + * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and + * debugging. This MicroBatchReader will *not* work in production applications due to multiple + * reasons, including no support for fault recovery. 
*/ -class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext) - extends Source with Logging { +class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { + + private var startOffset: Offset = _ + private var endOffset: Offset = _ + + private val host: String = options.get("host").get() + private val port: Int = options.get("port").get().toInt @GuardedBy("this") private var socket: Socket = null @@ -61,16 +68,21 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo * Stored in a ListBuffer to facilitate removing committed batches. */ @GuardedBy("this") - protected val batches = new ListBuffer[(String, Timestamp)] + private val batches = new ListBuffer[(String, Timestamp)] @GuardedBy("this") - protected var currentOffset: LongOffset = new LongOffset(-1) + private var currentOffset: LongOffset = LongOffset(-1L) @GuardedBy("this") - protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) + private var lastOffsetCommitted: LongOffset = LongOffset(-1L) initialize() + /** This method is only used for unit test */ + private[sources] def getCurrentOffset(): LongOffset = synchronized { + currentOffset.copy() + } + private def initialize(): Unit = synchronized { socket = new Socket(host, port) val reader = new BufferedReader(new InputStreamReader(socket.getInputStream)) @@ -86,12 +98,12 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo logWarning(s"Stream closed by $host:$port") return } - TextSocketSource.this.synchronized { + TextSocketMicroBatchReader.this.synchronized { val newData = (line, Timestamp.valueOf( - TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime())) - ) - currentOffset = currentOffset + 1 + TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + ) + currentOffset += 1 batches.append(newData) } } @@ -103,23 +115,37 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo readThread.start() } - /** Returns the schema of the data from this source */ - override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP - else TextSocketSource.SCHEMA_REGULAR + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { + startOffset = start.orElse(LongOffset(-1L)) + endOffset = end.orElse(currentOffset) + } + + override def getStartOffset(): Offset = { + Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) + } + + override def getEndOffset(): Offset = { + Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } - override def getOffset: Option[Offset] = synchronized { - if (currentOffset.offset == -1) { - None + override def readSchema(): StructType = { + if (options.getBoolean("includeTimestamp", false)) { + TextSocketMicroBatchReader.SCHEMA_TIMESTAMP } else { - Some(currentOffset) + TextSocketMicroBatchReader.SCHEMA_REGULAR } } - /** Returns the data that is between the offsets (`start`, `end`]. 
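The provider below validates host, port and the optional includeTimestamp flag, all of which are supplied as ordinary reader options. A hedged usage sketch (spark-shell style; assumes something like `nc -lk 9999` is listening on localhost):

val socketDF = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", "9999")
  .option("includeTimestamp", "true")   // adds a timestamp column next to the value column
  .load()

// socketDF.printSchema() shows value: string plus timestamp: timestamp when the flag is set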
*/ - override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { - val startOrdinal = - start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 - val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + assert(startOffset != null && endOffset != null, + "start offset and end offset should already be set before create read tasks.") + + val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 + val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -128,10 +154,34 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo batches.slice(sliceStart, sliceEnd) } - val rdd = sqlContext.sparkContext - .parallelize(rawList) - .map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) + assert(SparkSession.getActiveSession.isDefined) + val spark = SparkSession.getActiveSession.get + val numPartitions = spark.sparkContext.defaultParallelism + + val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + rawList.zipWithIndex.foreach { case (r, idx) => + slices(idx % numPartitions).append(r) + } + + (0 until numPartitions).map { i => + val slice = slices(i) + new DataReaderFactory[Row] { + override def createDataReader(): DataReader[Row] = new DataReader[Row] { + private var currentIdx = -1 + + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } + + override def get(): Row = { + Row(slice(currentIdx)._1, slice(currentIdx)._2) + } + + override def close(): Unit = {} + } + } + }.toList.asJava } override def commit(end: Offset): Unit = synchronized { @@ -164,54 +214,40 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo } } - override def toString: String = s"TextSocketSource[host: $host, port: $port]" + override def toString: String = s"TextSocketV2[host: $host, port: $port]" } -class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging { - private def parseIncludeTimestamp(params: Map[String, String]): Boolean = { - Try(params.getOrElse("includeTimestamp", "false").toBoolean) match { - case Success(bool) => bool - case Failure(_) => - throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") - } - } +class TextSocketSourceProvider extends DataSourceV2 + with MicroBatchReadSupport with DataSourceRegister with Logging { - /** Returns the name and schema of the source that can be used to continually read data. */ - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { + private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! 
" + "It does not support recovery.") - if (!parameters.contains("host")) { + if (!params.get("host").isPresent) { throw new AnalysisException("Set a host to read from with option(\"host\", ...).") } - if (!parameters.contains("port")) { + if (!params.get("port").isPresent) { throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } - if (schema.nonEmpty) { - throw new AnalysisException("The socket source does not support a user-specified schema.") + Try { + params.get("includeTimestamp").orElse("false").toBoolean + } match { + case Success(_) => + case Failure(_) => + throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") } - - val sourceSchema = - if (parseIncludeTimestamp(parameters)) { - TextSocketSource.SCHEMA_TIMESTAMP - } else { - TextSocketSource.SCHEMA_REGULAR - } - ("textSocket", sourceSchema) } - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val host = parameters("host") - val port = parameters("port").toInt - new TextSocketSource(host, port, parseIncludeTimestamp(parameters), sqlContext) + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + checkParameters(options) + if (schema.isPresent) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + + new TextSocketMicroBatchReader(options) } /** String that represents the format that this data source provider uses. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 3f5002a4e693..df722b953228 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException} +import java.io._ import java.nio.channels.ClosedChannelException import java.util.Locale @@ -27,13 +27,16 @@ import scala.util.Random import scala.util.control.NonFatal import com.google.common.io.ByteStreams +import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.io.LZ4CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.types.StructType import org.apache.spark.util.{SizeEstimator, Utils} @@ -87,10 +90,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case object ABORTED extends STATE private val newVersion = version + 1 - private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) @volatile private var state: STATE = UPDATING - @volatile private var finalDeltaFile: Path = 
null + private val finalDeltaFile: Path = deltaFile(newVersion) + private lazy val deltaFileStream = fm.createAtomic(finalDeltaFile, overwriteIfPossible = true) + private lazy val compressedStream = compressStream(deltaFileStream) override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId @@ -103,14 +106,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val keyCopy = key.copy() val valueCopy = value.copy() mapToUpdate.put(keyCopy, valueCopy) - writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy) + writeUpdateToDeltaFile(compressedStream, keyCopy, valueCopy) } override def remove(key: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") val prevValue = mapToUpdate.remove(key) if (prevValue != null) { - writeRemoveToDeltaFile(tempDeltaFileStream, key) + writeRemoveToDeltaFile(compressedStream, key) } } @@ -126,8 +129,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit verify(state == UPDATING, "Cannot commit after already committed or aborted") try { - finalizeDeltaFile(tempDeltaFileStream) - finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) + commitUpdates(newVersion, mapToUpdate, compressedStream) state = COMMITTED logInfo(s"Committed version $newVersion for $this to file $finalDeltaFile") newVersion @@ -140,23 +142,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Abort all the updates made on this store. This store will not be usable any more. */ override def abort(): Unit = { - verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed") - try { + // This if statement is to ensure that files are deleted only if there are changes to the + // StateStore. We have two StateStores for each task, one which is used only for reading, and + // the other used for read+write. We don't want the read-only to delete state files. + if (state == UPDATING) { + state = ABORTED + cancelDeltaFile(compressedStream, deltaFileStream) + } else { state = ABORTED - if (tempDeltaFileStream != null) { - tempDeltaFileStream.close() - } - if (tempDeltaFile != null) { - fs.delete(tempDeltaFile, true) - } - } catch { - case c: ClosedChannelException => - // This can happen when underlying file output stream has been closed before the - // compression stream. 
- logDebug(s"Error aborting version $newVersion into $this", c) - - case e: Exception => - logWarning(s"Error aborting version $newVersion into $this", e) } logInfo(s"Aborted version $newVersion for $this") } @@ -212,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit this.valueSchema = valueSchema this.storeConf = storeConf this.hadoopConf = hadoopConf - fs.mkdirs(baseDir) + fm.mkdirs(baseDir) } override def stateStoreId: StateStoreId = stateStoreId_ @@ -251,31 +244,15 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit private lazy val loadedMaps = new mutable.HashMap[Long, MapType] private lazy val baseDir = stateStoreId.storeCheckpointLocation() - private lazy val fs = baseDir.getFileSystem(hadoopConf) + private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) - /** Commit a set of updates to the store with the given new version */ - private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { + private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = { synchronized { - val finalDeltaFile = deltaFile(newVersion) - - // scalastyle:off - // Renaming a file atop an existing one fails on HDFS - // (http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html). - // Hence we should either skip the rename step or delete the target file. Because deleting the - // target file will break speculation, skipping the rename step is the only choice. It's still - // semantically correct because Structured Streaming requires rerunning a batch should - // generate the same output. 
(SPARK-19677) - // scalastyle:on - if (fs.exists(finalDeltaFile)) { - fs.delete(tempDeltaFile, true) - } else if (!fs.rename(tempDeltaFile, finalDeltaFile)) { - throw new IOException(s"Failed to rename $tempDeltaFile to $finalDeltaFile") - } + finalizeDeltaFile(output) loadedMaps.put(newVersion, map) - finalDeltaFile } } @@ -365,7 +342,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val fileToRead = deltaFile(version) var input: DataInputStream = null val sourceStream = try { - fs.open(fileToRead) + fm.open(fileToRead) } catch { case f: FileNotFoundException => throw new IllegalStateException( @@ -412,12 +389,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } private def writeSnapshotFile(version: Long, map: MapType): Unit = { - val fileToWrite = snapshotFile(version) - val tempFile = - new Path(fileToWrite.getParent, s"${fileToWrite.getName}.temp-${Random.nextLong}") + val targetFile = snapshotFile(version) + var rawOutput: CancellableFSDataOutputStream = null var output: DataOutputStream = null - Utils.tryWithSafeFinally { - output = compressStream(fs.create(tempFile, false)) + try { + rawOutput = fm.createAtomic(targetFile, overwriteIfPossible = true) + output = compressStream(rawOutput) val iter = map.entrySet().iterator() while(iter.hasNext) { val entry = iter.next() @@ -429,16 +406,34 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit output.write(valueBytes) } output.writeInt(-1) - } { - if (output != null) output.close() + output.close() + } catch { + case e: Throwable => + cancelDeltaFile(compressedStream = output, rawStream = rawOutput) + throw e } - if (fs.exists(fileToWrite)) { - // Skip rename if the file is alreayd created. - fs.delete(tempFile, true) - } else if (!fs.rename(tempFile, fileToWrite)) { - throw new IOException(s"Failed to rename $tempFile to $fileToWrite") + logInfo(s"Written snapshot file for version $version of $this at $targetFile") + } + + /** + * Try to cancel the underlying stream and safely close the compressed stream. + * + * @param compressedStream the compressed stream. + * @param rawStream the underlying stream which needs to be cancelled. + */ + private def cancelDeltaFile( + compressedStream: DataOutputStream, + rawStream: CancellableFSDataOutputStream): Unit = { + try { + if (rawStream != null) rawStream.cancel() + IOUtils.closeQuietly(compressedStream) + } catch { + case e: FSError if e.getCause.isInstanceOf[IOException] => + // Closing the compressedStream causes the stream to write/flush flush data into the + // rawStream. Since the rawStream is already closed, there may be errors. + // Usually its an IOException. However, Hadoop's RawLocalFileSystem wraps + // IOException into FSError. 
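(Hedged sketch of the CheckpointFileManager pattern used above: write atomically, cancel on failure. The directory, file name and payload are hypothetical; only calls that appear in this patch, namely create, createAtomic and cancel, are used.)

import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.execution.streaming.CheckpointFileManager

// Write payload to dir/name so that readers never observe a partially written file.
def writeAtomically(dir: Path, name: String, payload: Array[Byte], conf: Configuration): Unit = {
  val fm = CheckpointFileManager.create(dir, conf)
  val out = fm.createAtomic(new Path(dir, name), overwriteIfPossible = true)
  try {
    out.write(payload)
    out.close()        // close() finalizes the atomic write
  } catch {
    case e: Throwable =>
      out.cancel()     // discard the partially written data
      IOUtils.closeQuietly(out)
      throw e
  }
}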
} - logInfo(s"Written snapshot file for version $version of $this at $fileToWrite") } private def readSnapshotFile(version: Long): Option[MapType] = { @@ -447,7 +442,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit var input: DataInputStream = null try { - input = decompressStream(fs.open(fileToRead)) + input = decompressStream(fm.open(fileToRead)) var eof = false while (!eof) { @@ -508,7 +503,6 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case None => // The last map is not loaded, probably some other instance is in charge } - } } catch { case NonFatal(e) => @@ -534,7 +528,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) filesToDelete.foreach { f => - fs.delete(f.path, true) + fm.delete(f.path) } logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + filesToDelete.mkString(", ")) @@ -576,7 +570,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Fetch all the files that back the store */ private def fetchFiles(): Seq[StoreFile] = { val files: Seq[FileStatus] = try { - fs.listStatus(baseDir) + fm.list(baseDir) } catch { case _: java.io.FileNotFoundException => Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 6fe632f958ff..7eb68c21569b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -94,7 +94,7 @@ trait StateStore { def abort(): Unit /** - * Return an iterator containing all the key-value pairs in the SateStore. Implementations must + * Return an iterator containing all the key-value pairs in the StateStore. Implementations must * ensure that updates (puts, removes) can be made while iterating over this iterator. */ def iterator(): Iterator[UnsafeRowPair] @@ -459,7 +459,6 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - logInfo("Env is not null") val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER || env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER @@ -467,13 +466,12 @@ object StateStore extends Logging { // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver, // always recreate the reference. 
if (isDriver || _coordRef == null) { - logInfo("Getting StateStoreCoordinatorRef") + logDebug("Getting StateStoreCoordinatorRef") _coordRef = StateStoreCoordinatorRef.forExecutor(env) } logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) } else { - logInfo("Env is null") _coordRef = null None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b9b07a2e688f..1691a6320a52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -126,6 +126,12 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => name -> SQLMetrics.createTimingMetric(sparkContext, desc) }.toMap } + + /** + * Should the MicroBatchExecution run another batch based on this stateful operator and the + * current updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = false } /** An operator that supports watermark. */ @@ -340,37 +346,35 @@ case class StateStoreSaveExec( // Update and output modified rows from the StateStore. case Some(Update) => - val updatesStartTimeNs = System.nanoTime - - new Iterator[InternalRow] { - + new NextIterator[InternalRow] { // Filter late date using watermark if specified private[this] val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) case None => iter } + private val updatesStartTimeNs = System.nanoTime - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) - - // Remove old aggregates if watermark specified - allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } - commitTimeMs += timeTakenMs { store.commit() } - setStoreMetrics(store) - false + override protected def getNext(): InternalRow = { + if (baseIterator.hasNext) { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numOutputRows += 1 + numUpdatedStateRows += 1 + row } else { - true + finished = true + null } } - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) - numOutputRows += 1 - numUpdatedStateRows += 1 - row + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + // Remove old aggregates if watermark specified + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) } } @@ -390,6 +394,12 @@ case class StateStoreSaveExec( ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } /** Physical operator for executing streaming Deduplicate. 
*/ @@ -456,6 +466,10 @@ case class StreamingDeduplicateExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } object StreamingDeduplicateExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 7019d98e1619..582528777f90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -39,7 +39,8 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L val failed = new mutable.ArrayBuffer[SQLExecutionUIData]() sqlStore.executionsList().foreach { e => - val isRunning = e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } + val isRunning = e.completionTime.isEmpty || + e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } val isFailed = e.jobs.exists { case (_, status) => status == JobExecutionStatus.FAILED } if (isRunning) { running += e @@ -179,7 +180,7 @@ private[ui] abstract class ExecutionTable( } private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { - val details = if (execution.details.nonEmpty) { + val details = if (execution.details != null && execution.details.nonEmpty) { +details ++ @@ -190,8 +191,10 @@ private[ui] abstract class ExecutionTable( Nil } - val desc = { + val desc = if (execution.description != null && execution.description.nonEmpty) { {execution.description} + } else { + {execution.executionId} }
      <td>{desc} {details}</td>
      diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index f29e135ac357..e0554f0c4d33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -80,7 +80,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging planVisualization(metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse { -
      <div>No information to display for Plan {executionId}</div>
+      <div>No information to display for query {executionId}</div>
      } UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index d8adbe7bee13..d254af400a7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -88,7 +88,7 @@ class SQLAppStatusListener( exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING) exec.stages ++= event.stageIds.toSet - update(exec) + update(exec, force = true) } override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { @@ -99,7 +99,7 @@ class SQLAppStatusListener( // Reset the metrics tracking object for the new attempt. Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics => metrics.taskMetrics.clear() - metrics.attemptId = event.stageInfo.attemptId + metrics.attemptId = event.stageInfo.attemptNumber } } @@ -289,7 +289,7 @@ class SQLAppStatusListener( private def onDriverAccumUpdates(event: SparkListenerDriverAccumUpdates): Unit = { val SparkListenerDriverAccumUpdates(executionId, accumUpdates) = event Option(liveExecutions.get(executionId)).foreach { exec => - exec.driverAccumUpdates = accumUpdates.toMap + exec.driverAccumUpdates = exec.driverAccumUpdates ++ accumUpdates update(exec) } } @@ -308,11 +308,13 @@ class SQLAppStatusListener( }) } - private def update(exec: LiveExecutionData): Unit = { + private def update(exec: LiveExecutionData, force: Boolean = false): Unit = { val now = System.nanoTime() if (exec.endEvents >= exec.jobs.size + 1) { exec.write(kvstore, now) liveExecutions.remove(exec.executionId) + } else if (force) { + exec.write(kvstore, now) } else if (liveUpdatePeriodNs >= 0) { if (now - exec.lastWriteTime > liveUpdatePeriodNs) { exec.write(kvstore, now) @@ -332,9 +334,12 @@ class SQLAppStatusListener( return } - val toDelete = KVUtils.viewToSeq(kvstore.view(classOf[SQLExecutionUIData]), - countToDelete.toInt) { e => e.completionTime.isDefined } - toDelete.foreach { e => kvstore.delete(e.getClass(), e.executionId) } + val view = kvstore.view(classOf[SQLExecutionUIData]).index("completionTime").first(0L) + val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt)(_.completionTime.isDefined) + toDelete.foreach { e => + kvstore.delete(e.getClass(), e.executionId) + kvstore.delete(classOf[SparkPlanGraphWrapper], e.executionId) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index 910f2e52fdbb..241001a857c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -23,11 +23,12 @@ import java.util.Date import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import com.fasterxml.jackson.annotation.JsonIgnore import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.JobExecutionStatus import org.apache.spark.status.KVUtils.KVIndexParam -import org.apache.spark.util.kvstore.KVStore +import org.apache.spark.util.kvstore.{KVIndex, KVStore} /** * Provides a view of a KVStore with methods that make it easy to query SQL-specific state. 
There's @@ -53,6 +54,10 @@ class SQLAppStatusStore( store.count(classOf[SQLExecutionUIData]) } + def planGraphCount(): Long = { + store.count(classOf[SparkPlanGraphWrapper]) + } + def executionMetrics(executionId: Long): Map[Long, String] = { def metricsFromStore(): Option[Map[Long, String]] = { val exec = store.read(classOf[SQLExecutionUIData], executionId) @@ -90,7 +95,11 @@ class SQLExecutionUIData( * from the SQL listener instance. */ @JsonDeserialize(keyAs = classOf[JLong]) - val metricValues: Map[Long, String]) + val metricValues: Map[Long, String]) { + + @JsonIgnore @KVIndex("completionTime") + private def completionTimeIndex: Long = completionTime.map(_.getTime).getOrElse(-1L) +} class SparkPlanGraphWrapper( @KVIndexParam val executionId: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 800a2ea3f399..626f39d9e95c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -112,9 +112,11 @@ case class WindowExec( * * @param frame to evaluate. This can either be a Row or Range frame. * @param bound with respect to the row. + * @param timeZone the session local timezone for time related calculations. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = { + private[this] def createBoundOrdering( + frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { (frame, bound) match { case (RowFrame, CurrentRow) => RowBoundOrdering(0) @@ -144,7 +146,7 @@ case class WindowExec( val boundExpr = (expr.dataType, boundOffset.dataType) match { case (DateType, IntegerType) => DateAdd(expr, boundOffset) case (TimestampType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(conf.sessionLocalTimeZone)) + TimeAdd(expr, boundOffset, Some(timeZone)) case (a, b) if a== b => Add(expr, boundOffset) } val bound = newMutableProjection(boundExpr :: Nil, child.output) @@ -197,6 +199,7 @@ case class WindowExec( // Map the groups to a (unbound) expression and frame factory pair. var numExpressions = 0 + val timeZone = conf.sessionLocalTimeZone framedFunctions.toSeq.map { case (key, (expressions, functionSeq)) => val ordinal = numExpressions @@ -237,7 +240,7 @@ case class WindowExec( new UnboundedPrecedingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, upper, timeZone)) } // Shrinking Frame. @@ -246,7 +249,7 @@ case class WindowExec( new UnboundedFollowingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower)) + createBoundOrdering(frameType, lower, timeZone)) } // Moving Frame. 
@@ -255,8 +258,8 @@ case class WindowExec( new SlidingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower), - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, lower, timeZone), + createBoundOrdering(frameType, upper, timeZone)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 03b654f83052..bdc4bb4422ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.types.DataType * * As an example: * {{{ - * // Defined a UDF that returns true or false based on some numeric score. - * val predict = udf((score: Double) => if (score > 0.5) true else false) + * // Define a UDF that returns true or false based on some numeric score. + * val predict = udf((score: Double) => score > 0.5) * * // Projects a column that adds a prediction column based on the score column. * df.select( predict(df("score")) ) @@ -66,6 +66,7 @@ case class UserDefinedFunction protected[sql] ( * * @since 1.3.0 */ + @scala.annotation.varargs def apply(exprs: Column*): Column = { Column(ScalaUDF( f, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 1caa243f8d11..cd819bab1b14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -33,6 +33,10 @@ import org.apache.spark.sql.catalyst.expressions._ * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) * }}} * + * @note When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding, + * unboundedFollowing) is used by default. When ordering is defined, a growing window frame + * (rangeFrame, unboundedPreceding, currentRow) is used by default. + * * @since 1.4.0 */ @InterfaceStability.Stable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 530a525a01de..8f9e4ae18b3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,6 +24,7 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -132,7 +132,7 @@ object functions { * Returns a sort expression based on ascending order of the column, * and null values return before non-null values. 
* {{{ - * df.sort(asc_nulls_last("dept"), desc("age")) + * df.sort(asc_nulls_first("dept"), desc("age")) * }}} * * @group sort_funcs @@ -1313,8 +1313,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the cosine inverse of the given value; the returned angle is in the range - * 0.0 through pi. + * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -1322,8 +1321,7 @@ object functions { def acos(e: Column): Column = withExpr { Acos(e.expr) } /** - * Computes the cosine inverse of the given column; the returned angle is in the range - * 0.0 through pi. + * @return inverse cosine of `columnName`, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -1331,8 +1329,7 @@ object functions { def acos(columnName: String): Column = acos(Column(columnName)) /** - * Computes the sine inverse of the given value; the returned angle is in the range - * -pi/2 through pi/2. + * @return inverse sine of `e` in radians, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -1340,8 +1337,7 @@ object functions { def asin(e: Column): Column = withExpr { Asin(e.expr) } /** - * Computes the sine inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2. + * @return inverse sine of `columnName`, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -1349,8 +1345,7 @@ object functions { def asin(columnName: String): Column = asin(Column(columnName)) /** - * Computes the tangent inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2 + * @return inverse tangent of `e`, as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -1358,8 +1353,7 @@ object functions { def atan(e: Column): Column = withExpr { Atan(e.expr) } /** - * Computes the tangent inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2 + * @return inverse tangent of `columnName`, as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -1367,77 +1361,117 @@ object functions { def atan(columnName: String): Column = atan(Column(columnName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). Units in radians. + * @param y coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Column): Column = withExpr { Atan2(l.expr, r.expr) } + def atan2(y: Column, x: Column): Column = withExpr { Atan2(y.expr, x.expr) } /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). 
+ * @param y coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + def atan2(y: Column, xName: String): Column = atan2(y, Column(xName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + def atan2(yName: String, x: Column): Column = atan2(Column(yName), x) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, rightName: String): Column = - atan2(Column(leftName), Column(rightName)) + def atan2(yName: String, xName: String): Column = + atan2(Column(yName), Column(xName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param y coordinate on y-axis + * @param xValue coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Double): Column = atan2(l, lit(r)) + def atan2(y: Column, xValue: Double): Column = atan2(y, lit(xValue)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param xValue coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + def atan2(yName: String, xValue: Double): Column = atan2(Column(yName), xValue) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yValue coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, r: Column): Column = atan2(lit(l), r) + def atan2(yValue: Double, x: Column): Column = atan2(lit(yValue), x) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). 
+ * @param yValue coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + def atan2(yValue: Double, xName: String): Column = atan2(yValue, Column(xName)) /** * An expression that returns the string representation of the binary value of the given long @@ -1500,7 +1534,8 @@ object functions { } /** - * Computes the cosine of the given value. Units in radians. + * @param e angle in radians + * @return cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -1508,7 +1543,8 @@ object functions { def cos(e: Column): Column = withExpr { Cos(e.expr) } /** - * Computes the cosine of the given column. + * @param columnName angle in radians + * @return cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -1516,7 +1552,8 @@ object functions { def cos(columnName: String): Column = cos(Column(columnName)) /** - * Computes the hyperbolic cosine of the given value. + * @param e hyperbolic angle + * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -1524,7 +1561,8 @@ object functions { def cosh(e: Column): Column = withExpr { Cosh(e.expr) } /** - * Computes the hyperbolic cosine of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -1967,7 +2005,8 @@ object functions { def signum(columnName: String): Column = signum(Column(columnName)) /** - * Computes the sine of the given value. Units in radians. + * @param e angle in radians + * @return sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -1975,7 +2014,8 @@ object functions { def sin(e: Column): Column = withExpr { Sin(e.expr) } /** - * Computes the sine of the given column. + * @param columnName angle in radians + * @return sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -1983,7 +2023,8 @@ object functions { def sin(columnName: String): Column = sin(Column(columnName)) /** - * Computes the hyperbolic sine of the given value. + * @param e hyperbolic angle + * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -1991,7 +2032,8 @@ object functions { def sinh(e: Column): Column = withExpr { Sinh(e.expr) } /** - * Computes the hyperbolic sine of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -1999,7 +2041,8 @@ object functions { def sinh(columnName: String): Column = sinh(Column(columnName)) /** - * Computes the tangent of the given value. Units in radians. + * @param e angle in radians + * @return tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -2007,7 +2050,8 @@ object functions { def tan(e: Column): Column = withExpr { Tan(e.expr) } /** - * Computes the tangent of the given column. 
+ * @param columnName angle in radians + * @return tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -2015,7 +2059,8 @@ object functions { def tan(columnName: String): Column = tan(Column(columnName)) /** - * Computes the hyperbolic tangent of the given value. + * @param e hyperbolic angle + * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -2023,7 +2068,8 @@ object functions { def tanh(e: Column): Column = withExpr { Tanh(e.expr) } /** - * Computes the hyperbolic tangent of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -2047,6 +2093,9 @@ object functions { /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @param e angle in radians + * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * * @group math_funcs * @since 2.1.0 */ @@ -2055,6 +2104,9 @@ object functions { /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @param columnName angle in radians + * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * * @group math_funcs * @since 2.1.0 */ @@ -2077,6 +2129,9 @@ object functions { /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * + * @param e angle in degrees + * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * * @group math_funcs * @since 2.1.0 */ @@ -2085,6 +2140,9 @@ object functions { /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * + * @param columnName angle in degrees + * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * * @group math_funcs * @since 2.1.0 */ @@ -2170,16 +2228,6 @@ object functions { */ def base64(e: Column): Column = withExpr { Base64(e.expr) } - /** - * Concatenates multiple input columns together into a single column. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } - /** * Concatenates multiple input string columns together into a single string column, * using the given separator. @@ -2267,7 +2315,9 @@ object functions { } /** - * Computes the length of a given string or binary column. + * Computes the character length of a given string or number of bytes of a binary string. + * The length of character strings include the trailing spaces. The length of binary strings + * includes binary zeros. * * @group string_funcs * @since 1.5.0 @@ -2404,14 +2454,6 @@ object functions { StringRepeat(str.expr, lit(n).expr) } - /** - * Reverses the string column and returns it as a new string column. - * - * @group string_funcs - * @since 1.5.0 - */ - def reverse(str: Column): Column = withExpr { StringReverse(str.expr) } - /** * Trim the spaces from right end for the specified string value. * @@ -2649,11 +2691,22 @@ object functions { /** * Returns number of months between dates `date1` and `date2`. + * The result is rounded off to 8 digits. 
* @group datetime_funcs * @since 1.5.0 */ def months_between(date1: Column, date2: Column): Column = withExpr { - MonthsBetween(date1.expr, date2.expr) + new MonthsBetween(date1.expr, date2.expr) + } + + /** + * Returns number of months between dates `date1` and `date2`. If `roundOff` is set to true, the + * result is rounded off to 8 digits; it is not rounded otherwise. + * @group datetime_funcs + * @since 2.4.0 + */ + def months_between(date1: Column, date2: Column, roundOff: Boolean): Column = withExpr { + MonthsBetween(date1.expr, date2.expr, lit(roundOff).expr) } /** @@ -2871,7 +2924,7 @@ object functions { * or equal to the `windowDuration`. Check * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. + * according to a calendar. * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start * window intervals. For example, in order to have hourly tumbling windows that * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide @@ -2927,7 +2980,7 @@ object functions { * or equal to the `windowDuration`. Check * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. + * according to a calendar. * * @group datetime_funcs * @since 2.0.0 @@ -2986,6 +3039,79 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Returns an array containing all the elements in `x` from index `start` (or starting from the + * end if `start` is negative) with the specified `length`. + * @group collection_funcs + * @since 2.4.0 + */ + def slice(x: Column, start: Int, length: Int): Column = withExpr { + Slice(x.expr, Literal(start), Literal(length)) + } + + /** + * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + * `nullReplacement`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String, nullReplacement: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), Some(Literal(nullReplacement))) + } + + /** + * Concatenates the elements of `column` using the `delimiter`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), None) + } + + /** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + * + * @group collection_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } + + /** + * Locates the position of the first occurrence of the value in the given array as long. + * Returns null if either of the arguments are null. + * + * @note The position is not zero based, but 1 based index. Returns 0 if value + * could not be found in array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_position(column: Column, value: Any): Column = withExpr { + ArrayPosition(column.expr, Literal(value)) + } + + /** + * Returns element of array at given index in value if column is array. Returns value for + * the given key in value if column is map. 
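(Hedged usage sketch for the collection helpers added above, namely slice, array_join, array_position and element_at, together with the new months_between overload; the sample data is made up and a SparkSession named spark is assumed to be in scope.)

import org.apache.spark.sql.functions._
import spark.implicits._   // assumes an existing SparkSession `spark`

val df = Seq((Seq("a", "b", "c", "d"), "2018-03-31", "2018-01-15")).toDF("xs", "d1", "d2")

df.select(
  slice($"xs", 2, 2),                // ["b", "c"]
  array_join($"xs", ","),            // "a,b,c,d"
  array_position($"xs", "c"),        // 3 (1-based; 0 if the value is absent)
  element_at($"xs", 1),              // "a"
  months_between($"d1".cast("date"), $"d2".cast("date"), roundOff = false)
).show(truncate = false)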
+ * + * @group collection_funcs + * @since 2.4.0 + */ + def element_at(column: Column, value: Any): Column = withExpr { + ElementAt(column.expr, Literal(value)) + } + + /** + * Sorts the input array in ascending order. The elements of the input array must be orderable. + * Null elements will be placed at the end of the returned array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + /** * Creates a new row for each element in the given array or map column. * @@ -3072,7 +3198,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - JsonToStructs(schema, options, e.expr) + new JsonToStructs(schema, options, e.expr) } /** @@ -3225,6 +3351,7 @@ object functions { /** * Sorts the input array for the given column in ascending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array. * * @group collection_funcs * @since 1.5.0 @@ -3234,12 +3361,45 @@ object functions { /** * Sorts the input array for the given column in ascending or descending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array in ascending order or + * at the end of the returned array in descending order. * * @group collection_funcs * @since 1.5.0 */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the minimum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) } + + /** + * Returns the maximum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + + /** + * Returns a reversed string or an array with reverse order of elements. + * @group collection_funcs + * @since 1.5.0 + */ + def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + + /** + * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than + * two levels, only one level of nesting is removed. + * @group collection_funcs + * @since 2.4.0 + */ + def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs @@ -3254,42 +3414,66 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } - ////////////////////////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////////////////// - // scalastyle:off line.size.limit // scalastyle:off parameter.number /* Use the following code to generate: - (0 to 10).map { x => + + (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" - /** - * Defines a deterministic user-defined function of ${x} arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. 
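(A companion sketch for array_min, array_max, array_sort, the generalized reverse and flatten; the data and the in-scope SparkSession are again assumptions.)

import org.apache.spark.sql.functions._
import spark.implicits._   // assumes an existing SparkSession `spark`

val arrays = Seq((Seq(3, 1, 2), Seq(Seq(1, 2), Seq(3)))).toDF("xs", "nested")

arrays.select(
  array_min($"xs"),    // 1
  array_max($"xs"),    // 3
  array_sort($"xs"),   // [1, 2, 3]; null elements, if any, are placed last
  reverse($"xs"),      // [2, 1, 3]
  flatten($"nested")   // [1, 2, 3]; only one level of nesting is removed
).show(truncate = false)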
- * - * @group udf_funcs - * @since 1.3.0 - */ - def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Defines a Scala closure of $x arguments as user-defined function (UDF). + | * The data types are automatically inferred based on the Scala closure's + | * signature. By default the returned UDF is deterministic. To change it to + | * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 1.3.0 + | */ + |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | val udf = UserDefinedFunction(f, dataType, inputTypes) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) + } + + (0 to 10).foreach { i => + val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + val funcCall = if (i == 0) "() => func" else "func" + println(s""" + |/** + | * Defines a Java UDF$i instance as user-defined function (UDF). + | * The caller must specify the output data type, and there is no automatic input type coercion. + | * By default the returned UDF is deterministic. To change it to nondeterministic, call the + | * API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 2.3.0 + | */ + |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { + | val func = f$anyCast.call($anyParams) + | UserDefinedFunction($funcCall, returnType, inputTypes = None) + |}""".stripMargin) } */ + ////////////////////////////////////////////////////////////////////////////////////////////// + // Scala UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** - * Defines a deterministic user-defined function of 0 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 0 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3302,10 +3486,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 1 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 @@ -3318,10 +3502,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 2 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3334,10 +3518,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 3 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3350,10 +3534,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 4 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3366,10 +3550,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 5 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3382,10 +3566,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 6 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 6 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3398,10 +3582,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 7 arguments as user-defined - * function (UDF). 
The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 7 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3414,10 +3598,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 8 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 8 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3430,10 +3614,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 9 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 9 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3446,10 +3630,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 10 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 10 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3461,13 +3645,172 @@ object functions { if (nullable) udf else udf.asNonNullable() } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF0 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF0[Any]].call() + UserDefinedFunction(() => func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF1 instance as user-defined function (UDF). 
+ * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF2 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF3 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF4 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF5 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF6 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. 
+ * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF7 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF8 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF9 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF10 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + // scalastyle:on parameter.number // scalastyle:on line.size.limit /** * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, * the caller must specify the output data type, and there is no automatic input type coercion. - * To change a UDF to nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. 
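For example, a minimal usage sketch of these overloads (assuming a running `SparkSession` named `spark`; the column expressions are made up for the example):

{{{
import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.DataTypes

// Scala closure: input and output types are inferred from the closure's signature.
val plusOne = udf((n: Int) => n + 1)

// Java UDF instance: the result type must be given explicitly and inputs are not coerced.
val keyed = udf(new UDF2[Integer, String, String] {
  override def call(i: Integer, s: String): String = i.toString + s
}, DataTypes.StringType)

// Returned UDFs are deterministic by default; opt out explicitly when needed.
val noise = udf(() => math.random).asNondeterministic()

spark.range(3)
  .select(plusOne(col("id").cast("int")), keyed(col("id").cast("int"), col("id").cast("string")), noise())
  .show()
}}}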
* * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2867b4cd7da5..3a0db7e16c23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -130,8 +130,8 @@ abstract class BaseSessionStateBuilder( */ protected lazy val catalog: SessionCatalog = { val catalog = new SessionCatalog( - session.sharedState.externalCatalog, - session.sharedState.globalTempViewManager, + () => session.sharedState.externalCatalog, + () => session.sharedState.globalTempViewManager, functionRegistry, conf, SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), @@ -206,7 +206,7 @@ abstract class BaseSessionStateBuilder( /** * Logical query plan optimizer. * - * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields. + * Note: this depends on `catalog` and `experimentalMethods` fields. */ protected def optimizer: Optimizer = { new SparkOptimizer(catalog, experimentalMethods) { @@ -263,7 +263,7 @@ abstract class BaseSessionStateBuilder( * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. * - * This gets cloned from parent if available, otherwise is a new instance is created. + * This gets cloned from parent if available, otherwise a new instance is created. */ protected def listenerManager: ExecutionListenerManager = { parentState.map(_.listenerManager.clone()).getOrElse( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index b9515ec7bca2..eca612f06f9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -31,7 +31,8 @@ object HiveSerDe { "sequencefile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), "rcfile" -> HiveSerDe( @@ -54,7 +55,8 @@ object HiveSerDe { "textfile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), "avro" -> HiveSerDe( @@ -73,6 +75,7 @@ object HiveSerDe { val key = source.toLowerCase(Locale.ROOT) match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s if s.startsWith("org.apache.spark.sql.hive.orc") => "orc" case s if s.equals("orcfile") => "orc" case s if s.equals("parquetfile") => "parquet" case s if s.equals("avrofile") => "avro" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index baea4ceebf8e..5b6160e2b408 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -99,7 +99,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - lazy val externalCatalog: ExternalCatalog = { + lazy val externalCatalog: ExternalCatalogWithListener = { val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, @@ -117,14 +117,17 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) } + // Wrap to provide catalog events + val wrapped = new ExternalCatalogWithListener(externalCatalog) + // Make sure we propagate external catalog events to the spark listener bus - externalCatalog.addListener(new ExternalCatalogEventListener { + wrapped.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { sparkContext.listenerBus.post(event) } }) - externalCatalog + wrapped } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 2e92beecf2c1..ae93965bc50e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.Locale +import java.util.{Locale, Optional} import scala.collection.JavaConverters._ @@ -27,8 +27,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} -import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -117,7 +118,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * You can set the following option(s): *
        *
      • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps in the JSON/CSV datasources or partition values.
      • + * to be used to parse timestamps in the JSON/CSV data sources or partition values. *
      * * @since 2.0.0 @@ -128,12 +129,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } /** - * Adds input options for the underlying data source. + * (Java-specific) Adds input options for the underlying data source. * * You can set the following option(s): *
        *
      • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps in the JSON/CSV datasources or partition values.
      • + * to be used to parse timestamps in the JSON/CSV data sources or partition values. *
      * * @since 2.0.0 @@ -157,7 +158,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance() - val options = new DataSourceV2Options(extraOptions.asJava) + val options = new DataSourceOptions(extraOptions.asJava) // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. // We can't be sure at this point whether we'll actually want to use V2, since we don't know the // writer or whether the query is continuous. @@ -166,19 +167,41 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) + val v1Relation = ds match { + case _: StreamSourceProvider => Some(StreamingRelation(v1DataSource)) + case _ => None + } ds match { + case s: MicroBatchReadSupport => + var tempReader: MicroBatchReader = null + val schema = try { + tempReader = s.createMicroBatchReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + tempReader.readSchema() + } finally { + // Stop tempReader to avoid side-effect thing + if (tempReader != null) { + tempReader.stop() + tempReader = null + } + } + Dataset.ofRows( + sparkSession, + StreamingRelationV2( + s, source, extraOptions.toMap, + schema.toAttributes, v1Relation)(sparkSession)) case s: ContinuousReadSupport => val tempReader = s.createContinuousReader( - java.util.Optional.ofNullable(userSpecifiedSchema.orNull), + Optional.ofNullable(userSpecifiedSchema.orNull), Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, options) - // Generate the V1 node to catch errors thrown within generation. - StreamingRelation(v1DataSource) Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1DataSource)(sparkSession)) + tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) @@ -224,12 +247,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
    • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. *
        - *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` - * field in an output schema.
      • + *
      • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * keep corrupt records, a user can set a string type field named + * `columnNameOfCorruptRecord` in a user-defined schema. If a schema does not have the + * field, it drops corrupt records during parsing. When inferring a schema, it implicitly + * adds a `columnNameOfCorruptRecord` field in an output schema.
      • *
      • `DROPMALFORMED` : ignores the whole corrupted records.
      • *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • *
      @@ -245,6 +268,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • *
    • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
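For example, a minimal streaming-read sketch combining these options (assuming a `SparkSession` named `spark`; the input directory and schema are made up for the example):

{{{
import org.apache.spark.sql.types.{LongType, StringType, StructType}

// A string field matching columnNameOfCorruptRecord keeps malformed records in PERMISSIVE mode.
val schema = new StructType()
  .add("id", LongType)
  .add("name", StringType)
  .add("_corrupt_record", StringType)

val events = spark.readStream
  .schema(schema)
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_corrupt_record")
  .option("lineSep", "\n")
  .json("/tmp/json-input")
}}}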
    • * * * @since 2.0.0 @@ -304,12 +329,14 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
    • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. It supports the following case-insensitive modes. *
        - *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + *
      • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When a length of parsed CSV tokens is shorter than an expected length - * of a schema, it sets `null` for extra fields.
      • + * during parsing. A record with less/more tokens than schema is not a corrupted record to + * CSV. When it meets a record having fewer tokens than the length of the schema, sets + * `null` to extra fields. When the record has more tokens than the length of the schema, + * it drops extra tokens. *
      • `DROPMALFORMED` : ignores the whole corrupted records.
      • *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • *
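As a sketch of the CSV behaviour described above (assuming a `SparkSession` named `spark`; the path and schema are illustrative):

{{{
import org.apache.spark.sql.types.StructType

// Under PERMISSIVE mode, a row with fewer tokens than the schema gets nulls for the
// missing fields, and a row with more tokens has the surplus tokens dropped.
val csvSchema = StructType.fromDDL("a INT, b STRING, c DOUBLE")
val rows = spark.readStream
  .schema(csvSchema)
  .option("mode", "PERMISSIVE")
  .csv("/tmp/csv-input")
}}}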
      @@ -362,7 +389,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads text files and returns a `DataFrame` whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * - * Each line in the text files is a new row in the resulting DataFrame. For example: + * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: * spark.readStream.text("/path/to/directory/") @@ -375,6 +402,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
        *
      • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
      • + *
      • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
      • + *
      • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
      • *
      * * @since 2.0.0 @@ -388,7 +419,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * If the directory structure of the text files contains partitioning information, those are * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. * - * Each line in the text file is a new element in the resulting Dataset. For example: + * By default, each line in the text file is a new element in the resulting Dataset. For example: * {{{ * // Scala: * spark.readStream.textFile("/path/to/spark/README.md") @@ -401,6 +432,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
        *
      • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
      • + *
      • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
      • + *
      • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
      • *
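A small sketch of the text-source options above (assuming a `SparkSession` named `spark`; the directory path is illustrative):

{{{
// Default: one row per line ("value" column).
val lines = spark.readStream.text("/tmp/text-input")

// wholetext: each file is read as a single row instead of being split on line breaks.
val files = spark.readStream.option("wholetext", "true").text("/tmp/text-input")

// lineSep: split records on a custom separator; textFile returns a Dataset[String].
val records = spark.readStream.option("lineSep", "|").textFile("/tmp/text-input")
}}}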
      * * @param path input path diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index db588ae282f3..effc1471e8e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.sources.v2.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -268,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc) + val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -279,18 +280,25 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val dataSource = - DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") + val sink = ds.newInstance() match { + case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w + case _ => + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - dataSource.createSink(outputMode), + sink, outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index b508f4406138..7cefd03e43bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -29,10 +29,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import 
org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -240,26 +240,27 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo "is not supported in streaming DataFrames/Datasets and will be disabled.") } - sink match { - case v1Sink: Sink => - new StreamingQueryWrapper(new MicroBatchExecution( + (sink, trigger) match { + case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + new StreamingQueryWrapper(new ContinuousExecution( sparkSession, userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v1Sink, + v2Sink, trigger, triggerClock, outputMode, + extraOptions, deleteCheckpointOnStop)) - case v2Sink: ContinuousWriteSupport => - UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) - new StreamingQueryWrapper(new ContinuousExecution( + case _ => + new StreamingQueryWrapper(new MicroBatchExecution( sparkSession, userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v2Sink, + sink, trigger, triggerClock, outputMode, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index a73e4272950a..8bab7e1c5876 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -34,6 +34,8 @@ private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializab case that: ExamplePoint => this.x == that.x && this.y == that.y case _ => false } + + override def toString(): String = s"($x, $y)" } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index b007093dad84..69a2904f5f3f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.expressions.UserDefinedFunction; import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; import org.apache.spark.util.sketch.BloomFilter; @@ -455,4 +456,15 @@ public void testCircularReferenceBean() { CircularReference1Bean bean = new CircularReference1Bean(); spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class); } + + @Test + public void testUDF() { + UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType); + Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)) + .toArray(String[]::new); + String[] expected = spark.table("testData").collectAsList().stream() + .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); + Assert.assertArrayEquals(expected, result); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java index ddbaa45a483c..08dc129f27a0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java @@ 
-46,7 +46,7 @@ public void tearDown() { @SuppressWarnings("unchecked") @Test public void udf1Test() { - spark.range(1, 10).toDF("value").registerTempTable("df"); + spark.range(1, 10).toDF("value").createOrReplaceTempView("df"); spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName()); Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head(); Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 1cfdc08217e6..172e5d5eebcb 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,19 +24,20 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsPushDownRequiredColumns, + public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, SupportsPushDownFilters { - private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); - private Filter[] filters = new Filter[0]; + // Exposed for testing. 
+ public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); + public Filter[] filters = new Filter[0]; @Override public StructType readSchema() { @@ -50,8 +51,26 @@ public void pruneColumns(StructType requiredSchema) { @Override public Filter[] pushFilters(Filter[] filters) { - this.filters = filters; - return new Filter[0]; + Filter[] supported = Arrays.stream(filters).filter(f -> { + if (f instanceof GreaterThan) { + GreaterThan gt = (GreaterThan) f; + return gt.attribute().equals("i") && gt.value() instanceof Integer; + } else { + return false; + } + }).toArray(Filter[]::new); + + Filter[] unsupported = Arrays.stream(filters).filter(f -> { + if (f instanceof GreaterThan) { + GreaterThan gt = (GreaterThan) f; + return !gt.attribute().equals("i") || !(gt.value() instanceof Integer); + } else { + return true; + } + }).toArray(Filter[]::new); + + this.filters = supported; + return unsupported; } @Override @@ -60,8 +79,8 @@ public Filter[] pushedFilters() { } @Override - public List> createReadTasks() { - List> res = new ArrayList<>(); + public List> createDataReaderFactories() { + List> res = new ArrayList<>(); Integer lowerBound = null; for (Filter filter : filters) { @@ -75,25 +94,25 @@ public List> createReadTasks() { } if (lowerBound == null) { - res.add(new JavaAdvancedReadTask(0, 5, requiredSchema)); - res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); } else if (lowerBound < 4) { - res.add(new JavaAdvancedReadTask(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); } else if (lowerBound < 9) { - res.add(new JavaAdvancedReadTask(lowerBound + 1, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema)); } return res; } } - static class JavaAdvancedReadTask implements ReadTask, DataReader { + static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader { private int start; private int end; private StructType requiredSchema; - JavaAdvancedReadTask(int start, int end, StructType requiredSchema) { + JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) { this.start = start; this.end = end; this.requiredSchema = requiredSchema; @@ -101,7 +120,7 @@ static class JavaAdvancedReadTask implements ReadTask, DataReader { @Override public DataReader createDataReader() { - return new JavaAdvancedReadTask(start - 1, end, requiredSchema); + return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema); } @Override @@ -131,7 +150,7 @@ public void close() throws IOException { @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java new file mode 100644 index 000000000000..c55093768105 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceReader, SupportsScanColumnarBatch { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createBatchDataReaderFactories() { + return java.util.Arrays.asList( + new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90)); + } + } + + static class JavaBatchDataReaderFactory + implements DataReaderFactory, DataReader { + private int start; + private int end; + + private static final int BATCH_SIZE = 20; + + private OnHeapColumnVector i; + private OnHeapColumnVector j; + private ColumnarBatch batch; + + JavaBatchDataReaderFactory(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public DataReader createDataReader() { + this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + this.batch = new ColumnarBatch(vectors); + return this; + } + + @Override + public boolean next() { + i.reset(); + j.reset(); + int count = 0; + while (start < end && count < BATCH_SIZE) { + i.putInt(count, start); + j.putInt(count, -start); + start += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + } + + + @Override + public DataSourceReader createReader(DataSourceOptions options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java new file mode 100644 index 000000000000..32fad59b97ff --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; +import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; +import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; +import org.apache.spark.sql.types.StructType; + +public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceReader, SupportsReportPartitioning { + private final StructType schema = new StructType().add("a", "int").add("b", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createDataReaderFactories() { + return java.util.Arrays.asList( + new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + } + + @Override + public Partitioning outputPartitioning() { + return new MyPartitioning(); + } + } + + static class MyPartitioning implements Partitioning { + + @Override + public int numPartitions() { + return 2; + } + + @Override + public boolean satisfy(Distribution distribution) { + if (distribution instanceof ClusteredDistribution) { + String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; + return Arrays.asList(clusteredCols).contains("a"); + } + + return false; + } + } + + static class SpecificDataReaderFactory implements DataReaderFactory, DataReader { + private int[] i; + private int[] j; + private int current = -1; + + SpecificDataReaderFactory(int[] i, int[] j) { + assert i.length == j.length; + this.i = i; + this.j = j; + } + + @Override + public boolean next() throws IOException { + current += 1; + return current < i.length; + } + + @Override + public Row get() { + return new GenericRow(new Object[] {i[current], j[current]}); + } + + @Override + public void close() throws IOException { + + } + + @Override + public DataReader createDataReader() { + return this; + } + } + + @Override + public DataSourceReader createReader(DataSourceOptions options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index a174bd8092cb..048d078dfaac 100644 --- 
a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -20,16 +20,16 @@ import java.util.List; import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; -import org.apache.spark.sql.sources.v2.reader.ReadTask; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { - class Reader implements DataSourceV2Reader { + class Reader implements DataSourceReader { private final StructType schema; Reader(StructType schema) { @@ -42,13 +42,13 @@ public StructType readSchema() { } @Override - public List> createReadTasks() { + public List> createDataReaderFactories() { return java.util.Collections.emptyList(); } } @Override - public DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options) { + public DataSourceReader createReader(StructType schema, DataSourceOptions options) { return new Reader(schema); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 2d458b7f7e90..96f55b8a7681 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -23,16 +23,16 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.sources.v2.reader.ReadTask; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader { + class Reader implements DataSourceReader { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -41,25 +41,25 @@ public StructType readSchema() { } @Override - public List> createReadTasks() { + public List> createDataReaderFactories() { return java.util.Arrays.asList( - new JavaSimpleReadTask(0, 5), - new JavaSimpleReadTask(5, 10)); + new JavaSimpleDataReaderFactory(0, 5), + new JavaSimpleDataReaderFactory(5, 10)); } } - static class JavaSimpleReadTask implements ReadTask, DataReader { + static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader { private int start; private int end; - JavaSimpleReadTask(int start, int end) { + JavaSimpleDataReaderFactory(int start, int end) { this.start = start; this.end = end; } @Override public DataReader createDataReader() { - return 
new JavaSimpleReadTask(start - 1, end); + return new JavaSimpleDataReaderFactory(start - 1, end); } @Override @@ -80,7 +80,7 @@ public void close() throws IOException { } @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index f6aa00869a68..c3916e0b370b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -21,15 +21,15 @@ import java.util.List; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsScanUnsafeRow { + class Reader implements DataSourceReader, SupportsScanUnsafeRow { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -38,19 +38,20 @@ public StructType readSchema() { } @Override - public List> createUnsafeRowReadTasks() { + public List> createUnsafeRowReaderFactories() { return java.util.Arrays.asList( - new JavaUnsafeRowReadTask(0, 5), - new JavaUnsafeRowReadTask(5, 10)); + new JavaUnsafeRowDataReaderFactory(0, 5), + new JavaUnsafeRowDataReaderFactory(5, 10)); } } - static class JavaUnsafeRowReadTask implements ReadTask, DataReader { + static class JavaUnsafeRowDataReaderFactory + implements DataReaderFactory, DataReader { private int start; private int end; private UnsafeRow row; - JavaUnsafeRowReadTask(int start, int end) { + JavaUnsafeRowDataReaderFactory(int start, int end) { this.start = start; this.end = end; this.row = new UnsafeRow(2); @@ -59,7 +60,7 @@ static class JavaUnsafeRowReadTask implements ReadTask, DataReader createDataReader() { - return new JavaUnsafeRowReadTask(start - 1, end); + return new JavaUnsafeRowDataReaderFactory(start - 1, end); } @Override @@ -82,7 +83,7 @@ public void close() throws IOException { } @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index c6973bf41d34..46b38bed1c0f 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,3 +5,10 @@ org.apache.spark.sql.sources.FakeSourceFour org.apache.fakesource.FakeExternalSourceOne org.apache.fakesource.FakeExternalSourceTwo org.apache.fakesource.FakeExternalSourceThree +org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly +org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly +org.apache.spark.sql.streaming.sources.FakeReadBothModes 
+org.apache.spark.sql.streaming.sources.FakeReadNeitherMode +org.apache.spark.sql.streaming.sources.FakeWrite +org.apache.spark.sql.streaming.sources.FakeNoWrite +org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql index ad0f885f63d3..2909024e4c9f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql @@ -49,6 +49,7 @@ ALTER TABLE global_temp.global_temp_view CHANGE a a INT COMMENT 'this is column -- Change column in partition spec (not supported yet) CREATE TABLE partition_table(a INT, b STRING, c INT, d STRING) USING parquet PARTITIONED BY (c, d); ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT; +ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C'; -- DROP TEST TABLE DROP TABLE test_change; diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index adea2bfa82cd..4950a4b7a4e5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -25,3 +25,38 @@ create temporary view ttf2 as select * from values select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2; select a, b from ttf2 order by a, current_date; + +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); + +select from_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select from_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select from_utc_timestamp(null, 'PST'); + +select from_utc_timestamp('2015-07-24 00:00:00', null); + +select from_utc_timestamp(null, null); + +select from_utc_timestamp(cast(0 as timestamp), 'PST'); + +select from_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select to_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select to_utc_timestamp(null, 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', null); + +select to_utc_timestamp(null, null); + +select to_utc_timestamp(cast(0 as timestamp), 'PST'); + +select to_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +-- SPARK-23715: the input of to/from_utc_timestamp can not have timezone +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); + +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 1e1384549a41..c5070b734d52 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -60,3 +60,12 @@ SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a; -- Aggregate with empty input and empty GroupBy expressions. 
SELECT COUNT(1) FROM testData WHERE false; SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t; + +-- Aggregate with empty GroupBy expressions and filter on top +SELECT 1 from ( + SELECT 1 AS z, + MIN(a.x) + FROM (select 1 as x) a + WHERE false +) b +where b.z != b.z diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql index 37b4b7606d12..a743cf1ec2cd 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -105,3 +105,6 @@ select X'XuZ'; -- Hive literal_double test. SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8; + +-- map + interval test +select map(1, interval 1 day, 2, interval 3 week); diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql new file mode 100644 index 000000000000..01dea6c81c11 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -0,0 +1,113 @@ +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings); + +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s); + +-- pivot courses +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot years with no subquery +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); + +-- pivot courses with multiple aggregations +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column and with multiple aggregations on different columns +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple group by columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +); + +-- pivot on join query with multiple aggregations on different columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple columns in one aggregation +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with aliases and projection +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +); + +-- pivot years with non-aggregate function +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +); + +-- pivot with unresolvable columns +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql 
b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql index e99d5cef81f6..fadb4bb27fa1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql @@ -39,3 +39,10 @@ select 2.0 <= '2.2'; select 0.5 <= '1.5'; select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'; + +-- SPARK-23549: Cast to timestamp when comparing timestamp with date +select to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00'); +select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01'); +select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01'); +select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01'); +select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql index fb0d07fbdace..1661209093fc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql @@ -173,6 +173,16 @@ WHERE t1a = (SELECT max(t2a) HAVING count(*) >= 0) OR t1i > '2014-12-31'; +-- TC 02.03.01 +SELECT t1a +FROM t1 +WHERE t1a = (SELECT max(t2a) + FROM t2 + WHERE t2c = t1c + GROUP BY t2c + HAVING count(*) >= 1) +OR t1i > '2014-12-31'; + -- TC 02.04 -- t1 on the right of an outer join -- can be reduced to inner join diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index 0beebec5702f..db00a18f2e7e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -91,3 +91,65 @@ FROM ( encode(string(id + 3), 'utf-8') col4 FROM range(10) ); + +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +); + +-- Concatenate arrays of the same type 
+SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays; + +-- Concatenate arrays of different types +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index c8e108ac2c45..9be7fcdadfea 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -22,6 +22,52 @@ select a / b from t; select a % b from t; select pmod(a, b) from t; +-- tests for decimals handling in operations +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet; + +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789); + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + +-- arithmetic operations causing an overflow return NULL +select (5e36 + 0.1) + 5e36; +select (-4e36 - 0.1) - 7e36; +select 12345678901234567890.0 * 12345678901234567890.0; +select 1e35 / 0.1; + +-- arithmetic operations causing a precision loss are truncated +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; +select 123456789123456789.1234567890 * 1.123456789123456789; +select 12345678912345.123456789123 / 0.000000012345678; + +-- return NULL instead of rounding, according to old Spark versions' behavior +set spark.sql.decimalOperations.allowPrecisionLoss=false; + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; @@ -29,5 +75,8 @@ select 12345678901234567890.0 * 
12345678901234567890.0; select 1e35 / 0.1; -- arithmetic operations causing a precision loss return NULL +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; select 123456789123456789.1234567890 * 1.123456789123456789; -select 0.001 / 9876543210987654321098765432109876543.2 +select 12345678912345.123456789123 / 0.000000012345678; + +drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql new file mode 100644 index 000000000000..717616f91db0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql @@ -0,0 +1,44 @@ +-- Mixed inputs (output type is string) +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +); + +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- turn on eltOutputAsString +set spark.sql.function.eltOutputAsString=true; + +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + +-- turn off eltOutputAsString +set spark.sql.function.eltOutputAsString=false; + +-- Elt binary inputs (output type is binary) +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql index e57d69eaad03..6da1b9b49b22 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/union.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql @@ -35,6 +35,17 @@ FROM (SELECT col AS col SELECT col FROM p3) T1) T2; +-- SPARK-24012 Union of map and other compatible columns. +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1; + +-- SPARK-24012 Union of array and other compatible columns. 
+SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1; + + -- Clean-up DROP VIEW IF EXISTS t1; DROP VIEW IF EXISTS t2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index c4bea34ec4cf..cda4db4b449f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -76,7 +76,15 @@ ntile(2) OVER w AS ntile, row_number() OVER w AS row_number, var_pop(val) OVER w AS var_pop, var_samp(val) OVER w AS var_samp, -approx_count_distinct(val) OVER w AS approx_count_distinct +approx_count_distinct(val) OVER w AS approx_count_distinct, +covar_pop(val, val_long) OVER w AS covar_pop, +corr(val, val_long) OVER w AS corr, +stddev_samp(val) OVER w AS stddev_samp, +stddev_pop(val) OVER w AS stddev_pop, +collect_list(val) OVER w AS collect_list, +collect_set(val) OVER w AS collect_set, +skewness(val_double) OVER w AS skewness, +kurtosis(val_double) OVER w AS kurtosis FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val; diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index ba8bc936f0c7..ff1ecbcc44c2 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 33 -- !query 0 @@ -154,7 +154,7 @@ ALTER TABLE test_change CHANGE invalid_col invalid_col INT struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -Invalid column reference 'invalid_col', table schema is 'StructType(StructField(a,IntegerType,true), StructField(b,StringType,true), StructField(c,IntegerType,true))'; +Can't find column `invalid_col` given table data columns [`a`, `b`, `c`]; -- !query 16 @@ -291,16 +291,25 @@ ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT -- !query 30 -DROP TABLE test_change +ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C' -- !query 30 schema struct<> -- !query 30 output - +org.apache.spark.sql.AnalysisException +Can't find column `c` given table data columns [`a`, `b`]; -- !query 31 -DROP TABLE partition_table +DROP TABLE test_change -- !query 31 schema struct<> -- !query 31 output + + +-- !query 32 +DROP TABLE partition_table +-- !query 32 schema +struct<> +-- !query 32 output + diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index b5a4f5c2bf65..539f673c9d67 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -195,7 +195,7 @@ SELECT t1.x.y.* FROM t1 struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 't1.x.y.*' give input columns 'i1'; +cannot resolve 't1.x.y.*' given input columns 'i1'; -- !query 23 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out index 7c451c2aa5b5..209211960095 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -88,7 +88,7 
@@ SELECT global_temp.view1.* FROM global_temp.view1 struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve 'global_temp.view1.*' give input columns 'i1'; +cannot resolve 'global_temp.view1.*' given input columns 'i1'; -- !query 11 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out index d3ca4443cce5..e10f516ad6e5 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -179,7 +179,7 @@ SELECT mydb1.t1.* FROM mydb1.t1 struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' give input columns 'i1'; +cannot resolve 'mydb1.t1.*' given input columns 'i1'; -- !query 22 @@ -212,7 +212,7 @@ SELECT mydb1.t1.* FROM mydb1.t1 struct<> -- !query 25 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' give input columns 'i1'; +cannot resolve 'mydb1.t1.*' given input columns 'i1'; -- !query 26 @@ -420,7 +420,7 @@ SELECT mydb1.t5.* FROM mydb1.t5 struct<> -- !query 50 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t5.*' give input columns 'i1, t5'; +cannot resolve 'mydb1.t5.*' given input columns 'i1, t5'; -- !query 51 diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index bbb6851e69c7..9eede305dbdc 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 26 -- !query 0 @@ -81,3 +81,139 @@ struct -- !query 8 output 1 2 2 3 + + +-- !query 9 +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') +-- !query 9 schema +struct +-- !query 9 output +5 3 5 NULL 4 + + +-- !query 10 +select from_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 10 schema +struct +-- !query 10 output +2015-07-23 17:00:00 + + +-- !query 11 +select from_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 11 schema +struct +-- !query 11 output +2015-01-23 16:00:00 + + +-- !query 12 +select from_utc_timestamp(null, 'PST') +-- !query 12 schema +struct +-- !query 12 output +NULL + + +-- !query 13 +select from_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 13 schema +struct +-- !query 13 output +NULL + + +-- !query 14 +select from_utc_timestamp(null, null) +-- !query 14 schema +struct +-- !query 14 output +NULL + + +-- !query 15 +select from_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 15 schema +struct +-- !query 15 output +1969-12-31 08:00:00 + + +-- !query 16 +select from_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 16 schema +struct +-- !query 16 output +2015-01-23 16:00:00 + + +-- !query 17 +select to_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 17 schema +struct +-- !query 17 output +2015-07-24 07:00:00 + + +-- !query 18 +select to_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 18 schema +struct +-- !query 18 output +2015-01-24 08:00:00 + + +-- !query 19 +select to_utc_timestamp(null, 'PST') +-- !query 19 schema +struct +-- !query 19 output +NULL + + +-- !query 20 +select to_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select 
to_utc_timestamp(null, null) +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select to_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 22 schema +struct +-- !query 22 output +1970-01-01 00:00:00 + + +-- !query 23 +select to_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 23 schema +struct +-- !query 23 output +2015-01-24 08:00:00 + + +-- !query 24 +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 24 schema +struct +-- !query 24 output +NULL + + +-- !query 25 +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 25 schema +struct +-- !query 25 output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 986bb01c13fe..c1abc6dff754 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 25 +-- Number of queries: 26 -- !query 0 @@ -227,3 +227,17 @@ SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t struct<1:int> -- !query 24 output 1 + + +-- !query 25 +SELECT 1 from ( + SELECT 1 AS z, + MIN(a.x) + FROM (select 1 as x) a + WHERE false +) b +where b.z != b.z +-- !query 25 schema +struct<1:int> +-- !query 25 output + diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index d9dc728a18e8..14a69128ffb4 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -129,7 +129,7 @@ select to_json() struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json; line 1 pos 7 +Invalid number of arguments for function to_json. Expected: one of 1 and 2; Found: 0; line 1 pos 7 -- !query 13 @@ -225,7 +225,7 @@ select from_json() struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function from_json; line 1 pos 7 +Invalid number of arguments for function from_json. 
Expected: one of 2 and 3; Found: 0; line 1 pos 7 -- !query 22 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index 95d4413148f6..b8c91dc8b59a 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 43 +-- Number of queries: 44 -- !query 0 @@ -323,19 +323,17 @@ select timestamp '2016-33-11 20:54:00.000' -- !query 34 select interval 13.123456789 seconds, interval -13.123456789 second -- !query 34 schema -struct<> +struct -- !query 34 output -scala.MatchError -(interval 13 seconds 123 milliseconds 456 microseconds,CalendarIntervalType) (of class scala.Tuple2) +interval 13 seconds 123 milliseconds 456 microseconds interval -12 seconds -876 milliseconds -544 microseconds -- !query 35 select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond -- !query 35 schema -struct<> +struct -- !query 35 output -scala.MatchError -(interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds,CalendarIntervalType) (of class scala.Tuple2) +interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds 9 -- !query 36 @@ -416,3 +414,11 @@ SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8 struct<3.14:decimal(3,2),-3.14:decimal(3,2),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10),-3.14E+8:decimal(3,-6),-3.14E-8:decimal(10,10),3.14E+8:decimal(3,-6),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10)> -- !query 42 output 3.14 -3.14 314000000 0.0000000314 -314000000 -0.0000000314 314000000 314000000 0.0000000314 + + +-- !query 43 +select map(1, interval 1 day, 2, interval 3 week) +-- !query 43 schema +struct> +-- !query 43 output +{1:interval 1 days,2:interval 3 weeks} diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 237b618a8b90..840655b7a644 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -425,7 +425,7 @@ struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NUL -- !query 51 select BIT_LENGTH('abc') -- !query 51 schema -struct +struct -- !query 51 output 24 @@ -449,7 +449,7 @@ struct -- !query 54 select OCTET_LENGTH('abc') -- !query 54 schema -struct +struct -- !query 54 output 3 diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out new file mode 100644 index 000000000000..85e3488990e2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -0,0 +1,194 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 13 + + +-- !query 0 +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 
'Java') +) +-- !query 2 schema +struct +-- !query 2 output +2012 15000 20000 +2013 48000 30000 + + +-- !query 3 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 3 schema +struct +-- !query 3 output +Java 20000 30000 +dotNET 15000 48000 + + +-- !query 4 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 4 schema +struct +-- !query 4 output +2012 15000 7500.0 20000 20000.0 +2013 48000 48000.0 30000 30000.0 + + +-- !query 5 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 5 schema +struct +-- !query 5 output +63000 50000 + + +-- !query 6 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +) +-- !query 6 schema +struct +-- !query 6 output +63000 2012 50000 2012 + + +-- !query 7 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +) +-- !query 7 schema +struct +-- !query 7 output +Java 2012 20000 NULL +Java 2013 NULL 30000 +dotNET 2012 15000 NULL +dotNET 2013 NULL 48000 + + +-- !query 8 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +) +-- !query 8 schema +struct +-- !query 8 output +2012 15000 1 20000 1 +2013 48000 2 30000 2 + + +-- !query 9 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +) +-- !query 9 schema +struct +-- !query 9 output +2012 15000 20000 +2013 96000 60000 + + +-- !query 10 +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +) +-- !query 10 schema +struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> +-- !query 10 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 11 +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +) +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, found 'abs(earnings#x)'; + + +-- !query 12 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out index d51f6d37e4b4..cf828c69af62 100644 --- a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 37 -- !query 0 @@ -256,3 +256,43 @@ select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> -- !query 31 output true + + +-- !query 32 +select 
to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00') +-- !query 32 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) = to_timestamp('2017-03-01 00:00:00')):boolean> +-- !query 32 output +true + + +-- !query 33 +select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01') +-- !query 33 schema +struct<(to_timestamp('2017-03-01 00:00:01') > CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +-- !query 33 output +true + + +-- !query 34 +select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01') +-- !query 34 schema +struct<(to_timestamp('2017-03-01 00:00:01') >= CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +-- !query 34 output +true + + +-- !query 35 +select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01') +-- !query 35 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) < to_timestamp('2017-03-01 00:00:01')):boolean> +-- !query 35 output +true + + +-- !query 36 +select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01') +-- !query 36 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) <= to_timestamp('2017-03-01 00:00:01')):boolean> +-- !query 36 output +true diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index 8b29300e71f9..dd82efba0dde 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 27 -- !query 0 @@ -293,6 +293,22 @@ val1d -- !query 19 +SELECT t1a +FROM t1 +WHERE t1a = (SELECT max(t2a) + FROM t2 + WHERE t2c = t1c + GROUP BY t2c + HAVING count(*) >= 1) +OR t1i > '2014-12-31' +-- !query 19 schema +struct +-- !query 19 output +val1c +val1d + + +-- !query 20 SELECT count(t1a) FROM t1 RIGHT JOIN t2 ON t1d = t2d @@ -300,13 +316,13 @@ WHERE t1a < (SELECT max(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 19 schema +-- !query 20 schema struct --- !query 19 output +-- !query 20 output 7 --- !query 20 +-- !query 21 SELECT t1a FROM t1 WHERE t1b <= (SELECT max(t2b) @@ -317,14 +333,14 @@ AND t1b >= (SELECT min(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 20 schema +-- !query 21 schema struct --- !query 20 output +-- !query 21 output val1b val1c --- !query 21 +-- !query 22 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -338,14 +354,14 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 21 schema +-- !query 22 schema struct --- !query 21 output +-- !query 22 output val1b val1c --- !query 22 +-- !query 23 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -359,9 +375,9 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 22 schema +-- !query 23 schema struct --- !query 22 output +-- !query 23 output val1a val1a val1b @@ -372,7 +388,7 @@ val1d val1d --- !query 23 +-- !query 24 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -386,16 +402,16 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 23 schema +-- !query 24 schema struct --- !query 23 output +-- !query 24 output val1a val1b val1c val1d --- !query 24 +-- !query 25 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -409,13 +425,13 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 
24 schema +-- !query 25 schema struct --- !query 24 output +-- !query 25 output val1a --- !query 25 +-- !query 26 SELECT t1a FROM t1 GROUP BY t1a, t1c @@ -423,8 +439,8 @@ HAVING max(t1b) <= (SELECT max(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 25 schema +-- !query 26 schema struct --- !query 25 output +-- !query 26 output val1b val1c diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index 09729fdc2ec3..62befc5ca0f1 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -237,3 +237,81 @@ struct 78910 891011 9101112 + + +-- !query 11 +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +) +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays +-- !query 12 schema +struct,tinyint_array:array,smallint_array:array,int_array:array,bigint_array:array,decimal_array:array,double_array:array,float_array:array,data_array:array,timestamp_array:array,string_array:array,array_array:array>,struct_array:array>,map_array:array>> +-- !query 12 output +[true,false,true] [2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,3,4] [9223372036854775809,9223372036854775808,9223372036854775808,9223372036854775809] [2.0,1.0,3.0,4.0] [2.0,1.0,3.0,4.0] [2016-03-14,2016-03-13,2016-03-12,2016-03-11] [2016-11-15 20:54:00.0,2016-11-12 20:54:00.0,2016-11-11 20:54:00.0] 
["a","b","c","d"] [["a","b"],["c","d"],["e"],["f"]] [{"col1":"a","col2":1},{"col1":"b","col2":2},{"col1":"c","col2":3},{"col1":"d","col2":4}] [{"a":1},{"b":2},{"c":3},{"d":4}] + + +-- !query 13 +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays +-- !query 13 schema +struct,si_array:array,ib_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> +-- !query 13 output +[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index ce02f6adc456..6bfdb84548d4 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 36 -- !query 0 @@ -35,48 +35,269 @@ NULL -- !query 4 -select (5e36 + 0.1) + 5e36 +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet -- !query 4 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 4 output -NULL + -- !query 5 -select (-4e36 - 0.1) - 7e36 +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789) -- !query 5 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 5 output -NULL + -- !query 6 -select 12345678901234567890.0 * 12345678901234567890.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 6 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct -- !query 6 output -NULL +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 -- !query 7 -select 1e35 / 0.1 +select id, a*10, b/10 from decimals_test order by id -- !query 7 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +struct -- !query 7 output -NULL +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 -- !query 8 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 10.3 * 3.0 -- !query 8 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 8 output -NULL +30.9 -- !query 9 -select 0.001 / 9876543210987654321098765432109876543.2 +select 10.3000 * 3.0 -- !query 9 schema -struct<(CAST(0.001 AS DECIMAL(38,3)) / CAST(9876543210987654321098765432109876543.2 
AS DECIMAL(38,3))):decimal(38,37)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 9 output +30.9 + + +-- !query 10 +select 10.30000 * 30.0 +-- !query 10 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 10 output +309 + + +-- !query 11 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 11 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 11 output +30.9 + + +-- !query 12 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 12 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 12 output +30.9 + + +-- !query 13 +select (5e36 + 0.1) + 5e36 +-- !query 13 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 13 output +NULL + + +-- !query 14 +select (-4e36 - 0.1) - 7e36 +-- !query 14 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 14 output +NULL + + +-- !query 15 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 15 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 15 output +NULL + + +-- !query 16 +select 1e35 / 0.1 +-- !query 16 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +-- !query 16 output +NULL + + +-- !query 17 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +-- !query 17 schema +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> +-- !query 17 output +10012345678912345678912345678911.246907 + + +-- !query 18 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 18 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +-- !query 18 output +138698367904130467.654320988515622621 + + +-- !query 19 +select 12345678912345.123456789123 / 0.000000012345678 +-- !query 19 schema +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> +-- !query 19 output +1000000073899961059796.725866332 + + +-- !query 20 +set spark.sql.decimalOperations.allowPrecisionLoss=false +-- !query 20 schema +struct +-- !query 20 output +spark.sql.decimalOperations.allowPrecisionLoss false + + +-- !query 21 +select id, a+b, a-b, a*b, a/b from decimals_test order by id +-- !query 21 schema +struct +-- !query 21 output +1 1099 -899 NULL 0.1001001001001001 +2 24690.246 0 NULL 1 +3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 +4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 + + +-- !query 22 +select id, a*10, b/10 from decimals_test order by id +-- !query 22 schema +struct +-- !query 22 output +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.1123456789123456789 + + +-- !query 23 +select 10.3 * 3.0 +-- !query 23 schema +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +-- !query 23 output +30.9 + + +-- !query 24 +select 10.3000 * 3.0 +-- !query 24 schema +struct<(CAST(10.3000 AS 
DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +-- !query 24 output +30.9 + + +-- !query 25 +select 10.30000 * 30.0 +-- !query 25 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 25 output +309 + + +-- !query 26 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 26 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> +-- !query 26 output +30.9 + + +-- !query 27 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 27 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> +-- !query 27 output +NULL + + +-- !query 28 +select (5e36 + 0.1) + 5e36 +-- !query 28 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 28 output +NULL + + +-- !query 29 +select (-4e36 - 0.1) - 7e36 +-- !query 29 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 29 output NULL + + +-- !query 30 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 30 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 30 output +NULL + + +-- !query 31 +select 1e35 / 0.1 +-- !query 31 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +-- !query 31 output +NULL + + +-- !query 32 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +-- !query 32 schema +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +-- !query 32 output +NULL + + +-- !query 33 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 33 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +-- !query 33 output +NULL + + +-- !query 34 +select 12345678912345.123456789123 / 0.000000012345678 +-- !query 34 schema +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +-- !query 34 output +NULL + + +-- !query 35 +drop table decimals_test +-- !query 35 schema +struct<> +-- !query 35 output + diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out index ebc8201ed5a1..6ee7f59d6987 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out @@ -2329,7 +2329,7 @@ struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(C -- !query 280 SELECT cast(1 as bigint) / cast(1 as decimal(20, 0)) FROM t -- !query 280 schema -struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,19)> +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,18)> -- !query 280 output 1 @@ -2661,7 +2661,7 @@ struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BI -- !query 320 SELECT cast(1 as decimal(20, 0)) / cast(1 as bigint) FROM t -- !query 320 schema -struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) 
AS DECIMAL(20,0))):decimal(38,19)> +struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,18)> -- !query 320 output 1 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out new file mode 100644 index 000000000000..b62e1b682604 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out @@ -0,0 +1,115 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +) +-- !query 0 schema +struct +-- !query 0 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 1 +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 1 schema +struct +-- !query 1 output +10 +11 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 2 +set spark.sql.function.eltOutputAsString=true +-- !query 2 schema +struct +-- !query 2 output +spark.sql.function.eltOutputAsString true + + +-- !query 3 +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 3 schema +struct +-- !query 3 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 4 +set spark.sql.function.eltOutputAsString=false +-- !query 4 schema +struct +-- !query 4 output +spark.sql.function.eltOutputAsString false + + +-- !query 5 +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 5 schema +struct +-- !query 5 output +1 +10 +2 +3 +4 +5 +6 +7 +8 +9 diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out index d123b7fdbe0c..b023df825d81 100644 --- a/sql/core/src/test/resources/sql-tests/results/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 16 -- !query 0 @@ -105,23 +105,29 @@ struct -- !query 9 -DROP VIEW IF EXISTS t1 +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1 -- !query 9 schema -struct<> +struct,str:string> -- !query 9 output - +{1:2,3:null} 1 +{1:2} str -- !query 10 -DROP VIEW IF EXISTS t2 +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1 -- !query 10 schema -struct<> +struct,str:string> -- !query 10 output - +[1,2,3,null] 1 +[1,2] str -- !query 11 -DROP VIEW IF EXISTS p1 +DROP VIEW IF EXISTS t1 -- !query 11 schema struct<> -- !query 11 output @@ -129,7 +135,7 @@ struct<> -- !query 12 -DROP VIEW IF EXISTS p2 +DROP VIEW IF EXISTS t2 -- !query 12 schema struct<> -- !query 12 output @@ -137,8 +143,24 @@ struct<> -- !query 13 -DROP VIEW IF EXISTS p3 +DROP VIEW IF EXISTS p1 -- !query 13 schema struct<> -- !query 13 output + + +-- !query 14 +DROP VIEW IF EXISTS p2 +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +DROP VIEW IF EXISTS p3 +-- !query 15 schema +struct<> +-- !query 15 output + diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index 
133458ae9303..4afbcd62853d 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -273,22 +273,30 @@ ntile(2) OVER w AS ntile, row_number() OVER w AS row_number, var_pop(val) OVER w AS var_pop, var_samp(val) OVER w AS var_samp, -approx_count_distinct(val) OVER w AS approx_count_distinct +approx_count_distinct(val) OVER w AS approx_count_distinct, +covar_pop(val, val_long) OVER w AS covar_pop, +corr(val, val_long) OVER w AS corr, +stddev_samp(val) OVER w AS stddev_samp, +stddev_pop(val) OVER w AS stddev_pop, +collect_list(val) OVER w AS collect_list, +collect_set(val) OVER w AS collect_set, +skewness(val_double) OVER w AS skewness, +kurtosis(val_double) OVER w AS kurtosis FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val -- !query 17 schema -struct +struct,collect_set:array,skewness:double,kurtosis:double> -- !query 17 output -NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 -3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 -NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 -1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 -1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 -2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 -1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 -2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 -3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 +NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL +3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 0.0 NaN NaN 0.0 [3] [3] NaN NaN +NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NaN NaN +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 +2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 4.772185885555555E8 1.0 0.5773502691896258 0.4714045207910317 [1,1,2] [1,2] 1.1539890888012805 -0.6672217220327235 +1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 NULL NULL NaN 0.0 [1] [1] NaN NaN +2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NaN 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013 +3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 5.3687091175E8 1.0 1.0 0.816496580927726 [1,2,3] [1,2,3] 0.7057890433107311 -1.4999999999999984 -- !query 18 diff --git a/sql/core/src/test/resources/test-data/utf16LE.json b/sql/core/src/test/resources/test-data/utf16LE.json new file mode 100644 index 000000000000..ce4117fd299d Binary files /dev/null and b/sql/core/src/test/resources/test-data/utf16LE.json differ diff --git a/sql/core/src/test/resources/test-data/utf16WithBOM.json b/sql/core/src/test/resources/test-data/utf16WithBOM.json new file mode 100644 index 000000000000..cf4d29328b86 Binary files 
/dev/null and b/sql/core/src/test/resources/test-data/utf16WithBOM.json differ diff --git a/sql/core/src/test/resources/test-data/utf32BEWithBOM.json b/sql/core/src/test/resources/test-data/utf32BEWithBOM.json new file mode 100644 index 000000000000..6c7733c57787 Binary files /dev/null and b/sql/core/src/test/resources/test-data/utf32BEWithBOM.json differ diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql new file mode 100644 index 000000000000..50e521567eb3 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql @@ -0,0 +1,69 @@ +-- This is a new query in TPCDS v2.7 +select + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3, + cd_dep_count, + count(*) cnt4, + cd_dep_employed_count, + count(*) cnt5, + cd_dep_college_count, + count(*) cnt6 +from + customer c,customer_address ca,customer_demographics +where + c.c_current_addr_sk = ca.ca_address_sk + and ca_county in ('Walker County', 'Richland County', 'Gaines County', 'Douglas County', 'Dona Ana County') + and cd_demo_sk = c.c_current_cdemo_sk + and exists ( + select * + from store_sales,date_dim + where c.c_customer_sk = ss_customer_sk + and ss_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3) + and exists ( + select * + from ( + select + ws_bill_customer_sk as customer_sk, + d_year, + d_moy + from web_sales, date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3 + union all + select + cs_ship_customer_sk as customer_sk, + d_year, + d_moy + from catalog_sales, date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3) x + where c.c_customer_sk = customer_sk) +group by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql new file mode 100755 index 000000000000..97bed3372174 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql @@ -0,0 +1,84 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ss_ext_list_price - ss_ext_discount_amt) year_total, + 's' sale_type + FROM customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + GROUP BY c_customer_id + , c_first_name + , c_last_name + , d_year + , c_preferred_cust_flag + , c_birth_country + , c_login + , c_email_address + , d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ws_ext_list_price - ws_ext_discount_amt) year_total, + 'w' sale_type + FROM customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND 
ws_sold_date_sk = d_date_sk + GROUP BY + c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, + c_login, c_email_address, d_year) +SELECT + -- select list of q11 in TPCDS v1.4 is below: + -- t_s_secyear.customer_preferred_cust_flag + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_email_address +FROM year_total t_s_firstyear + , year_total t_s_secyear + , year_total t_w_firstyear + , year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.dyear = 2001 + AND t_s_secyear.dyear = 2001 + 1 + AND t_w_firstyear.dyear = 2001 + AND t_w_secyear.dyear = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + -- q11 in TPCDS v1.4 used NULL + -- ELSE NULL END + ELSE 0.0 END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + -- q11 in TPCDS v1.4 used NULL + -- ELSE NULL END + ELSE 0.0 END +ORDER BY + -- order-by list of q11 in TPCDS v1.4 is below: + -- t_s_secyear.customer_preferred_cust_flag + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_email_address +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql new file mode 100755 index 000000000000..7a6fafd22428 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql @@ -0,0 +1,23 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ws_ext_sales_price) AS itemrevenue, + sum(ws_ext_sales_price) * 100 / sum(sum(ws_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + web_sales, item, date_dim +WHERE + ws_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ws_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql new file mode 100644 index 000000000000..b2ca3ddaf2ba --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql @@ -0,0 +1,135 @@ +-- This query is the alternative form of sql/core/src/test/resources/tpcds/q14a.sql +with cross_items as ( + select + i_item_sk ss_item_sk + from item, ( + select + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + from + store_sales, item iss, date_dim d1 + where + ss_item_sk = iss.i_item_sk + and ss_sold_date_sk = d1.d_date_sk + and d1.d_year between 1998 AND 1998 + 2 + intersect + select + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + from + catalog_sales, item ics, date_dim d2 + where + cs_item_sk = ics.i_item_sk + and cs_sold_date_sk = d2.d_date_sk + and d2.d_year between 1998 AND 1998 + 2 + intersect + select + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + from + web_sales, item 
iws, date_dim d3 + where + ws_item_sk = iws.i_item_sk + and ws_sold_date_sk = d3.d_date_sk + and d3.d_year between 1998 AND 1998 + 2) x + where + i_brand_id = brand_id + and i_class_id = class_id + and i_category_id = category_id), +avg_sales as ( + select + avg(quantity*list_price) average_sales + from ( + select + ss_quantity quantity, + ss_list_price list_price + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + cs_quantity quantity, + cs_list_price list_price + from + catalog_sales, date_dim + where + cs_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + ws_quantity quantity, + ws_list_price list_price + from + web_sales, date_dim + where + ws_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2) x) +select + * +from ( + select + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_week_seq = ( + select d_week_seq + from date_dim + where d_year = 1998 + 1 + and d_moy = 12 + and d_dom = 16) + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity*ss_list_price) > (select average_sales from avg_sales)) this_year, + ( + select + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_week_seq = ( + select d_week_seq + from date_dim + where d_year = 1998 + and d_moy = 12 + and d_dom = 16) + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity * ss_list_price) > (select average_sales from avg_sales)) last_year +where + this_year.i_brand_id = last_year.i_brand_id + and this_year.i_class_id = last_year.i_class_id + and this_year.i_category_id = last_year.i_category_id +order by + this_year.channel, + this_year.i_brand_id, + this_year.i_class_id, + this_year.i_category_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql new file mode 100644 index 000000000000..bfa70fe62d8d --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql @@ -0,0 +1,215 @@ +-- This query is the alternative form of sql/core/src/test/resources/tpcds/q14b.sql +with cross_items as ( + select + i_item_sk ss_item_sk + from item, ( + select + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + from + store_sales, item iss, date_dim d1 + where + ss_item_sk = iss.i_item_sk + and ss_sold_date_sk = d1.d_date_sk + and d1.d_year between 1999 AND 1999 + 2 + intersect + select + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + from + catalog_sales, item ics, date_dim d2 + where + cs_item_sk = ics.i_item_sk + and cs_sold_date_sk = d2.d_date_sk + and d2.d_year between 1999 AND 1999 + 2 + intersect + select + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + from + web_sales, item iws, date_dim d3 + where + ws_item_sk = iws.i_item_sk + and ws_sold_date_sk = d3.d_date_sk + and d3.d_year between 1999 AND 1999 + 2) x + where + i_brand_id = brand_id + and i_class_id = class_id + and i_category_id = category_id), +avg_sales as ( + select + 
avg(quantity*list_price) average_sales + from ( + select + ss_quantity quantity, + ss_list_price list_price + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_year between 1999 and 2001 + union all + select + cs_quantity quantity, + cs_list_price list_price + from + catalog_sales, date_dim + where + cs_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + ws_quantity quantity, + ws_list_price list_price + from + web_sales, date_dim + where + ws_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2) x), +results AS ( + select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum(sales) sum_sales, + sum(number_sales) number_sales + from ( + select + 'store' channel, + i_brand_id,i_class_id, + i_category_id, + sum(ss_quantity*ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_year = 1998 + 2 + and d_moy = 11 + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity * ss_list_price) > (select average_sales from avg_sales) + union all + select + 'catalog' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(cs_quantity*cs_list_price) sales, + count(*) number_sales + from + catalog_sales, item, date_dim + where + cs_item_sk in (select ss_item_sk from cross_items) + and cs_item_sk = i_item_sk + and cs_sold_date_sk = d_date_sk + and d_year = 1998+2 + and d_moy = 11 + group by + i_brand_id,i_class_id,i_category_id + having + sum(cs_quantity*cs_list_price) > (select average_sales from avg_sales) + union all + select + 'web' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ws_quantity*ws_list_price) sales, + count(*) number_sales + from + web_sales, item, date_dim + where + ws_item_sk in (select ss_item_sk from cross_items) + and ws_item_sk = i_item_sk + and ws_sold_date_sk = d_date_sk + and d_year = 1998 + 2 + and d_moy = 11 + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ws_quantity*ws_list_price) > (select average_sales from avg_sales)) y + group by + channel, + i_brand_id, + i_class_id, + i_category_id) +select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum_sales, + number_sales +from ( + select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum_sales, + number_sales + from + results + union + select + channel, + i_brand_id, + i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel, + i_brand_id, + i_class_id + union + select + channel, + i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel, + i_brand_id + union + select + channel, + null as i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel + union + select + null as channel, + null as i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results) z +order by + channel, + i_brand_id, + i_class_id, + i_category_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql new file mode 100644 index 000000000000..2201a302ab35 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql @@ -0,0 +1,133 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_item_id, + 
ca_country, + ca_state, + ca_county, + cast(cs_quantity as decimal(12,2)) agg1, + cast(cs_list_price as decimal(12,2)) agg2, + cast(cs_coupon_amt as decimal(12,2)) agg3, + cast(cs_sales_price as decimal(12,2)) agg4, + cast(cs_net_profit as decimal(12,2)) agg5, + cast(c_birth_year as decimal(12,2)) agg6, + cast(cd1.cd_dep_count as decimal(12,2)) agg7 + from + catalog_sales, customer_demographics cd1, customer_demographics cd2, customer, + customer_address, date_dim, item + where + cs_sold_date_sk = d_date_sk + and cs_item_sk = i_item_sk + and cs_bill_cdemo_sk = cd1.cd_demo_sk + and cs_bill_customer_sk = c_customer_sk + and cd1.cd_gender = 'M' + and cd1.cd_education_status = 'College' + and c_current_cdemo_sk = cd2.cd_demo_sk + and c_current_addr_sk = ca_address_sk + and c_birth_month in (9,5,12,4,1,10) + and d_year = 2001 + and ca_state in ('ND','WI','AL','NC','OK','MS','TN')) +select + i_item_id, + ca_country, + ca_state, + ca_county, + agg1, + agg2, + agg3, + agg4, + agg5, + agg6, + agg7 +from ( + select + i_item_id, + ca_country, + ca_state, + ca_county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from + results + group by + i_item_id, + ca_country, + ca_state, + ca_county + union all + select + i_item_id, + ca_country, + ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from + results + group by + i_item_id, + ca_country, + ca_state + union all + select + i_item_id, + ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results + group by + i_item_id, + ca_country + union all + select + i_item_id, + NULL as ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results + group by + i_item_id + union all + select + NULL AS i_item_id, + NULL as ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results) foo +order by + ca_country, + ca_state, + ca_county, + i_item_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql new file mode 100755 index 000000000000..34d46b1394d8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql @@ -0,0 +1,19 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(cs_ext_sales_price) AS itemrevenue, + sum(cs_ext_sales_price) * 100 / sum(sum(cs_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM catalog_sales, item, date_dim +WHERE cs_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND cs_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) +AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql new file mode 100755 index 000000000000..e7bea0804f16 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql @@ 
-0,0 +1,15 @@ +SELECT + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh +FROM inventory, date_dim, item, warehouse +WHERE inv_date_sk = d_date_sk + AND inv_item_sk = i_item_sk + -- q22 in TPCDS v1.4 had a condition below: + -- AND inv_warehouse_sk = w_warehouse_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 +GROUP BY ROLLUP (i_product_name, i_brand, i_class, i_category) +ORDER BY qoh, i_product_name, i_brand, i_class, i_category +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql new file mode 100644 index 000000000000..c886e6271511 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql @@ -0,0 +1,94 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh + from + inventory, date_dim, item, warehouse + where + inv_date_sk = d_date_sk + and inv_item_sk = i_item_sk + and inv_warehouse_sk = w_warehouse_sk + and d_month_seq between 1212 and 1212 + 11 + group by + i_product_name, + i_brand, + i_class, + i_category), +results_rollup as ( + select + i_product_name, + i_brand, + i_class, + i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand, + i_class, + i_category + union all + select + i_product_name, + i_brand, + i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand, + i_class + union all + select + i_product_name, + i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand + union all + select + i_product_name, + null i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name + union all + select + null i_product_name, + null i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results) +select + i_product_name, + i_brand, + i_class, + i_category, + qoh +from + results_rollup +order by + qoh, + i_product_name, + i_brand, + i_class, + i_category +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql new file mode 100755 index 000000000000..92d64bc7eba7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql @@ -0,0 +1,40 @@ +WITH ssales AS +(SELECT + c_last_name, + c_first_name, + s_store_name, + ca_state, + s_state, + i_color, + i_current_price, + i_manager_id, + i_units, + i_size, + sum(ss_net_paid) netpaid + FROM store_sales, store_returns, store, item, customer, customer_address + WHERE ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk + AND ss_customer_sk = c_customer_sk + AND ss_item_sk = i_item_sk + AND ss_store_sk = s_store_sk + AND c_current_addr_sk = ca_address_sk -- This condition did not exist in TPCDS v1.4 + AND c_birth_country = upper(ca_country) + AND s_zip = ca_zip + AND s_market_id = 8 + GROUP BY c_last_name, c_first_name, s_store_name, ca_state, s_state, i_color, + i_current_price, i_manager_id, i_units, i_size) +SELECT + c_last_name, + c_first_name, + s_store_name, + sum(netpaid) paid +FROM ssales +WHERE i_color = 'pale' +GROUP BY c_last_name, c_first_name, s_store_name +HAVING sum(netpaid) > (SELECT 0.05 * avg(netpaid) +FROM ssales) +-- no order-by exists in q24a of TPCDS v1.4 +ORDER BY + c_last_name, + c_first_name, + s_store_name diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql new file mode 100644 
index 000000000000..c70a2420e838 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_item_id, + s_state, 0 as g_state, + ss_quantity agg1, + ss_list_price agg2, + ss_coupon_amt agg3, + ss_sales_price agg4 + from + store_sales, customer_demographics, date_dim, store, item + where + ss_sold_date_sk = d_date_sk + and ss_item_sk = i_item_sk + and ss_store_sk = s_store_sk + and ss_cdemo_sk = cd_demo_sk + and cd_gender = 'F' + and cd_marital_status = 'W' + and cd_education_status = 'Primary' + and d_year = 1998 + and s_state in ('TN','TN', 'TN', 'TN', 'TN', 'TN')) +select + i_item_id, + s_state, + g_state, + agg1, + agg2, + agg3, + agg4 +from ( + select + i_item_id, + s_state, + 0 as g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from + results + group by + i_item_id, + s_state + union all + select + i_item_id, + NULL AS s_state, + 1 AS g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from results + group by + i_item_id + union all + select + NULL AS i_item_id, + NULL as s_state, + 1 as g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from + results) foo +order by + i_item_id, + s_state +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql new file mode 100755 index 000000000000..bbede62acc9a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql @@ -0,0 +1,37 @@ +SELECT + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + count(*) cnt + FROM store_sales, date_dim, store, household_demographics + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND (date_dim.d_dom BETWEEN 1 AND 3 OR date_dim.d_dom BETWEEN 25 AND 28) + AND (household_demographics.hd_buy_potential = '>10000' OR + household_demographics.hd_buy_potential = 'unknown') + AND household_demographics.hd_vehicle_count > 0 + AND (CASE WHEN household_demographics.hd_vehicle_count > 0 + THEN household_demographics.hd_dep_count / household_demographics.hd_vehicle_count + ELSE NULL + END) > 1.2 + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_county IN + ('Williamson County', 'Williamson County', 'Williamson County', 'Williamson County', + 'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County') + GROUP BY ss_ticket_number, ss_customer_sk) dn, customer +WHERE ss_customer_sk = c_customer_sk + AND cnt BETWEEN 15 AND 20 +ORDER BY + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag DESC, + ss_ticket_number -- This order-by condition did not exist in TPCDS v1.4 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql new file mode 100755 index 000000000000..27116a563d5c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql @@ -0,0 +1,65 @@ +SELECT + -- select list of q35 in TPCDS v1.4 is below: + -- ca_state, + -- cd_gender, + -- cd_marital_status, + -- count(*) cnt1, + -- min(cd_dep_count), + -- max(cd_dep_count), + -- avg(cd_dep_count), + -- cd_dep_employed_count, + -- count(*) cnt2, + -- min(cd_dep_employed_count), + -- max(cd_dep_employed_count), + -- avg(cd_dep_employed_count), + -- cd_dep_college_count, + 
-- count(*) cnt3, + -- min(cd_dep_college_count), + -- max(cd_dep_college_count), + -- avg(cd_dep_college_count) + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + count(*) cnt1, + avg(cd_dep_count), + max(cd_dep_count), + sum(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + avg(cd_dep_employed_count), + max(cd_dep_employed_count), + sum(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + avg(cd_dep_college_count), + max(cd_dep_college_count), + sum(cd_dep_college_count) +FROM + customer c, customer_address ca, customer_demographics +WHERE + c.c_current_addr_sk = ca.ca_address_sk AND + cd_demo_sk = c.c_current_cdemo_sk AND + exists(SELECT * + FROM store_sales, date_dim + WHERE c.c_customer_sk = ss_customer_sk AND + ss_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) AND + (exists(SELECT * + FROM web_sales, date_dim + WHERE c.c_customer_sk = ws_bill_customer_sk AND + ws_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) OR + exists(SELECT * + FROM catalog_sales, date_dim + WHERE c.c_customer_sk = cs_ship_customer_sk AND + cs_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4)) +GROUP BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +ORDER BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql new file mode 100644 index 000000000000..1c1463e44777 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql @@ -0,0 +1,62 @@ +-- This is a new query in TPCDS v2.7 +select + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + count(*) cnt1, + avg(cd_dep_count), + max(cd_dep_count), + sum(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + avg(cd_dep_employed_count), + max(cd_dep_employed_count), + sum(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + avg(cd_dep_college_count), + max(cd_dep_college_count), + sum(cd_dep_college_count) +from + customer c, customer_address ca, customer_demographics +where + c.c_current_addr_sk = ca.ca_address_sk + and cd_demo_sk = c.c_current_cdemo_sk + and exists ( + select * + from store_sales, date_dim + where c.c_customer_sk = ss_customer_sk + and ss_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4) + and exists ( + select * + from ( + select ws_bill_customer_sk customsk + from web_sales, date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4 + union all + select cs_ship_customer_sk customsk + from catalog_sales, date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4) x + where x.customsk = c.c_customer_sk) +group by + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql new file mode 100644 index 000000000000..9d98f32add50 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ss_net_profit) as ss_net_profit, + sum(ss_ext_sales_price) as ss_ext_sales_price, + sum(ss_net_profit)/sum(ss_ext_sales_price) as gross_margin, + i_category, + i_class, + 0 as g_category, + 0 
as g_class + from + store_sales, + date_dim d1, + item, + store + where + d1.d_year = 2001 + and d1.d_date_sk = ss_sold_date_sk + and i_item_sk = ss_item_sk + and s_store_sk = ss_store_sk + and s_state in ('TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN') + group by + i_category, + i_class), + results_rollup as ( + select + gross_margin, + i_category, + i_class, + 0 as t_category, + 0 as t_class, + 0 as lochierarchy + from + results + union + select + sum(ss_net_profit) / sum(ss_ext_sales_price) as gross_margin, + i_category, NULL AS i_class, + 0 as t_category, + 1 as t_class, + 1 as lochierarchy + from + results + group by + i_category + union + select + sum(ss_net_profit) / sum(ss_ext_sales_price) as gross_margin, + NULL AS i_category, + NULL AS i_class, + 1 as t_category, + 1 as t_class, + 2 as lochierarchy + from + results) +select + gross_margin, + i_category, + i_class, + lochierarchy, + rank() over ( + partition by lochierarchy, case when t_class = 0 then i_category end + order by gross_margin asc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then i_category end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql new file mode 100755 index 000000000000..9f7ee457ea45 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql @@ -0,0 +1,64 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + s_store_name, + s_company_name, + d_year, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name + ORDER BY d_year, d_moy) rn + FROM item, store_sales, date_dim, store + WHERE ss_item_sk = i_item_sk AND + ss_sold_date_sk = d_date_sk AND + ss_store_sk = s_store_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + s_store_name, s_company_name, + d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + -- q47 in TPCDS v1.4 had more columns below: + -- v1.i_brand, + -- v1.s_store_name, + -- v1.s_company_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.s_store_name = v1_lag.s_store_name AND + v1.s_store_name = v1_lead.s_store_name AND + v1.s_company_name = v1_lag.s_company_name AND + v1.s_company_name = v1_lead.s_company_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql new file mode 100755 index 000000000000..e8061bde4159 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql @@ -0,0 +1,133 @@ +-- The first SELECT query below is different from q49 of TPCDS v1.4 +SELECT + channel, + item, + return_ratio, + return_rank, + currency_rank +FROM ( + SELECT + 'web' as channel, + in_web.item, + in_web.return_ratio, + 
in_web.return_rank, + in_web.currency_rank + FROM + (SELECT + item, + return_ratio, + currency_ratio, + rank() over (ORDER BY return_ratio) AS return_rank, + rank() over (ORDER BY currency_ratio) AS currency_rank + FROM ( + SELECT + ws.ws_item_sk AS item, + CAST(SUM(COALESCE(wr.wr_return_quantity, 0)) AS DECIMAL(15, 4)) / + CAST(SUM(COALESCE(ws.ws_quantity, 0)) AS DECIMAL(15, 4)) AS return_ratio, + CAST(SUM(COALESCE(wr.wr_return_amt, 0)) AS DECIMAL(15, 4)) / + CAST(SUM(COALESCE(ws.ws_net_paid, 0)) AS DECIMAL(15, 4)) AS currency_ratio + FROM + web_sales ws LEFT OUTER JOIN web_returns wr + ON (ws.ws_order_number = wr.wr_order_number AND ws.ws_item_sk = wr.wr_item_sk), + date_dim + WHERE + wr.wr_return_amt > 10000 + AND ws.ws_net_profit > 1 + AND ws.ws_net_paid > 0 + AND ws.ws_quantity > 0 + AND ws_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY + ws.ws_item_sk) + ) in_web + ) web +WHERE (web.return_rank <= 10 OR web.currency_rank <= 10) +UNION +SELECT + 'catalog' AS channel, + catalog.item, + catalog.return_ratio, + catalog.return_rank, + catalog.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + cs.cs_item_sk AS item, + (cast(sum(coalesce(cr.cr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(cr.cr_return_amount, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + catalog_sales cs LEFT OUTER JOIN catalog_returns cr + ON (cs.cs_order_number = cr.cr_order_number AND + cs.cs_item_sk = cr.cr_item_sk) + , date_dim + WHERE + cr.cr_return_amount > 10000 + AND cs.cs_net_profit > 1 + AND cs.cs_net_paid > 0 + AND cs.cs_quantity > 0 + AND cs_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY cs.cs_item_sk + ) in_cat + ) catalog +WHERE (catalog.return_rank <= 10 OR catalog.currency_rank <= 10) +UNION +SELECT + 'store' AS channel, + store.item, + store.return_ratio, + store.return_rank, + store.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + sts.ss_item_sk AS item, + (cast(sum(coalesce(sr.sr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(sr.sr_return_amt, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + store_sales sts LEFT OUTER JOIN store_returns sr + ON (sts.ss_ticket_number = sr.sr_ticket_number AND sts.ss_item_sk = sr.sr_item_sk) + , date_dim + WHERE + sr.sr_return_amt > 10000 + AND sts.ss_net_profit > 1 + AND sts.ss_net_paid > 0 + AND sts.ss_quantity > 0 + AND ss_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY sts.ss_item_sk + ) in_store + ) store +WHERE (store.return_rank <= 10 OR store.currency_rank <= 10) +ORDER BY + -- order-by list of q49 in TPCDS v1.4 is below: + -- 1, 4, 5 + 1, 4, 5, 2 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql new file mode 100644 index 000000000000..b8cbbbc8ef7d --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql @@ -0,0 +1,103 @@ +-- This is a new query in TPCDS v2.7 +WITH 
web_tv as ( + select + ws_item_sk item_sk, + d_date, + sum(ws_sales_price) sumws, + row_number() over (partition by ws_item_sk order by d_date) rk + from + web_sales, date_dim + where + ws_sold_date_sk=d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ws_item_sk is not NULL + group by + ws_item_sk, d_date), +web_v1 as ( + select + v1.item_sk, + v1.d_date, + v1.sumws, + sum(v2.sumws) cume_sales + from + web_tv v1, web_tv v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.sumws), +store_tv as ( + select + ss_item_sk item_sk, + d_date, + sum(ss_sales_price) sumss, + row_number() over (partition by ss_item_sk order by d_date) rk + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_item_sk is not NULL + group by ss_item_sk, d_date), +store_v1 as ( + select + v1.item_sk, + v1.d_date, + v1.sumss, + sum(v2.sumss) cume_sales + from + store_tv v1, store_tv v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.sumss), +v as ( + select + item_sk, + d_date, + web_sales, + store_sales, + row_number() over (partition by item_sk order by d_date) rk + from ( + select + case when web.item_sk is not null + then web.item_sk + else store.item_sk end item_sk, + case when web.d_date is not null + then web.d_date + else store.d_date end d_date, + web.cume_sales web_sales, + store.cume_sales store_sales + from + web_v1 web full outer join store_v1 store + on (web.item_sk = store.item_sk and web.d_date = store.d_date))) +select * +from ( + select + v1.item_sk, + v1.d_date, + v1.web_sales, + v1.store_sales, + max(v2.web_sales) web_cumulative, + max(v2.store_sales) store_cumulative + from + v v1, v v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.web_sales, + v1.store_sales) x +where + web_cumulative > store_cumulative +order by + item_sk, + d_date +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql new file mode 100755 index 000000000000..ccefaac3c12c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql @@ -0,0 +1,57 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + cc_name, + d_year, + d_moy, + sum(cs_sales_price) sum_sales, + avg(sum(cs_sales_price)) + OVER + (PARTITION BY i_category, i_brand, cc_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, cc_name + ORDER BY d_year, d_moy) rn + FROM item, catalog_sales, date_dim, call_center + WHERE cs_item_sk = i_item_sk AND + cs_sold_date_sk = d_date_sk AND + cc_call_center_sk = cs_call_center_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + cc_name, d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + v1.i_brand, + -- q57 in TPCDS v1.4 had a column below: + -- v1.cc_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.cc_name = v1_lag.cc_name AND + v1.cc_name = v1_lead.cc_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - 
avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql new file mode 100644 index 000000000000..42bcf59c2aeb --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql @@ -0,0 +1,158 @@ +-- This is a new query in TPCDS v2.7 +with ssr as( + select + s_store_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + ss_store_sk as store_sk, + ss_sold_date_sk as date_sk, + ss_ext_sales_price as sales_price, + ss_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from + store_sales + union all + select + sr_store_sk as store_sk, + sr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + sr_return_amt as return_amt, + sr_net_loss as net_loss + from + store_returns) salesreturns, + date_dim, + store + where + date_sk = d_date_sk and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and store_sk = s_store_sk + group by + s_store_id), +csr as ( + select + cp_catalog_page_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + cs_catalog_page_sk as page_sk, + cs_sold_date_sk as date_sk, + cs_ext_sales_price as sales_price, + cs_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from catalog_sales + union all + select + cr_catalog_page_sk as page_sk, + cr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + cr_return_amount as return_amt, + cr_net_loss as net_loss + from catalog_returns) salesreturns, + date_dim, + catalog_page + where + date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and page_sk = cp_catalog_page_sk + group by + cp_catalog_page_id), +wsr as ( + select + web_site_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + ws_web_site_sk as wsr_web_site_sk, + ws_sold_date_sk as date_sk, + ws_ext_sales_price as sales_price, + ws_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from + web_sales + union all + select + ws_web_site_sk as wsr_web_site_sk, + wr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + wr_return_amt as return_amt, + wr_net_loss as net_loss + from + web_returns + left outer join web_sales on ( + wr_item_sk = ws_item_sk and wr_order_number = ws_order_number) + ) salesreturns, + date_dim, + web_site + where + date_sk = d_date_sk and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and wsr_web_site_sk = web_site_sk + group by + web_site_id), +results as ( + select + channel, + id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from ( + select + 'store channel' as channel, + 'store' || s_store_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + ssr + union all + select + 'catalog channel' as channel, + 'catalog_page' || cp_catalog_page_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + 
csr + union all + select + 'web channel' as channel, + 'web_site' || web_site_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + wsr) x + group by + channel, id) +select + channel, id, sales, returns, profit +from ( + select channel, id, sales, returns, profit + from results + union + select channel, null as id, sum(sales), sum(returns), sum(profit) + from results + group by channel + union + select null as channel, null as id, sum(sales), sum(returns), sum(profit) + from results) foo + order by channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql new file mode 100755 index 000000000000..c0bfa40ad44a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql @@ -0,0 +1,23 @@ +SELECT + a.ca_state state, + count(*) cnt +FROM + customer_address a, customer c, store_sales s, date_dim d, item i +WHERE a.ca_address_sk = c.c_current_addr_sk + AND c.c_customer_sk = s.ss_customer_sk + AND s.ss_sold_date_sk = d.d_date_sk + AND s.ss_item_sk = i.i_item_sk + AND d.d_month_seq = + (SELECT DISTINCT (d_month_seq) + FROM date_dim + WHERE d_year = 2000 AND d_moy = 1) + AND i.i_current_price > 1.2 * + (SELECT avg(j.i_current_price) + FROM item j + WHERE j.i_category = i.i_category) +GROUP BY a.ca_state +HAVING count(*) >= 10 +-- order-by list of q6 in TPCDS v1.4 is below: +-- order by cnt +order by cnt, a.ca_state +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql new file mode 100755 index 000000000000..cdcd8486b363 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql @@ -0,0 +1,111 @@ +WITH cs_ui AS +(SELECT + cs_item_sk, + sum(cs_ext_list_price) AS sale, + sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit) AS refund + FROM catalog_sales + , catalog_returns + WHERE cs_item_sk = cr_item_sk + AND cs_order_number = cr_order_number + GROUP BY cs_item_sk + HAVING sum(cs_ext_list_price) > 2 * sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit)), + cross_sales AS + (SELECT + i_product_name product_name, + i_item_sk item_sk, + s_store_name store_name, + s_zip store_zip, + ad1.ca_street_number b_street_number, + ad1.ca_street_name b_streen_name, + ad1.ca_city b_city, + ad1.ca_zip b_zip, + ad2.ca_street_number c_street_number, + ad2.ca_street_name c_street_name, + ad2.ca_city c_city, + ad2.ca_zip c_zip, + d1.d_year AS syear, + d2.d_year AS fsyear, + d3.d_year s2year, + count(*) cnt, + sum(ss_wholesale_cost) s1, + sum(ss_list_price) s2, + sum(ss_coupon_amt) s3 + FROM store_sales, store_returns, cs_ui, date_dim d1, date_dim d2, date_dim d3, + store, customer, customer_demographics cd1, customer_demographics cd2, + promotion, household_demographics hd1, household_demographics hd2, + customer_address ad1, customer_address ad2, income_band ib1, income_band ib2, item + WHERE ss_store_sk = s_store_sk AND + ss_sold_date_sk = d1.d_date_sk AND + ss_customer_sk = c_customer_sk AND + ss_cdemo_sk = cd1.cd_demo_sk AND + ss_hdemo_sk = hd1.hd_demo_sk AND + ss_addr_sk = ad1.ca_address_sk AND + ss_item_sk = i_item_sk AND + ss_item_sk = sr_item_sk AND + ss_ticket_number = sr_ticket_number AND + ss_item_sk = cs_ui.cs_item_sk AND + c_current_cdemo_sk = cd2.cd_demo_sk AND + c_current_hdemo_sk = hd2.hd_demo_sk AND + c_current_addr_sk = ad2.ca_address_sk AND + c_first_sales_date_sk = d2.d_date_sk AND + c_first_shipto_date_sk = d3.d_date_sk AND + ss_promo_sk = p_promo_sk AND + hd1.hd_income_band_sk = ib1.ib_income_band_sk 
AND + hd2.hd_income_band_sk = ib2.ib_income_band_sk AND + cd1.cd_marital_status <> cd2.cd_marital_status AND + i_color IN ('purple', 'burlywood', 'indian', 'spring', 'floral', 'medium') AND + i_current_price BETWEEN 64 AND 64 + 10 AND + i_current_price BETWEEN 64 + 1 AND 64 + 15 + GROUP BY + i_product_name, + i_item_sk, + s_store_name, + s_zip, + ad1.ca_street_number, + ad1.ca_street_name, + ad1.ca_city, + ad1.ca_zip, + ad2.ca_street_number, + ad2.ca_street_name, + ad2.ca_city, + ad2.ca_zip, + d1.d_year, + d2.d_year, + d3.d_year + ) +SELECT + cs1.product_name, + cs1.store_name, + cs1.store_zip, + cs1.b_street_number, + cs1.b_streen_name, + cs1.b_city, + cs1.b_zip, + cs1.c_street_number, + cs1.c_street_name, + cs1.c_city, + cs1.c_zip, + cs1.syear, + cs1.cnt, + cs1.s1, + cs1.s2, + cs1.s3, + cs2.s1, + cs2.s2, + cs2.s3, + cs2.syear, + cs2.cnt +FROM cross_sales cs1, cross_sales cs2 +WHERE cs1.item_sk = cs2.item_sk AND + cs1.syear = 1999 AND + cs2.syear = 1999 + 1 AND + cs2.cnt <= cs1.cnt AND + cs1.store_name = cs2.store_name AND + cs1.store_zip = cs2.store_zip +ORDER BY + cs1.product_name, + cs1.store_name, + cs2.cnt, + -- The two columns below are newly added in TPCDS v2.7 + cs1.s1, + cs2.s1 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql new file mode 100644 index 000000000000..70a14043bbb3 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql @@ -0,0 +1,208 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sum(coalesce(ss_sales_price * ss_quantity, 0)) sumsales + from + store_sales, date_dim, store, item + where + ss_sold_date_sk=d_date_sk + and ss_item_sk=i_item_sk + and ss_store_sk = s_store_sk + and d_month_seq between 1212 and 1212 + 11 + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id), +results_rollup as ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales + from + results + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year + union all + select + i_category, + i_class, + i_brand, + i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name + union all + select + i_category, + i_class, + i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand + union all + select + i_category, + i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group 
by + i_category, + i_class + union all + select + i_category, + null i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from results + group by + i_category + union all + select + null i_category, + null i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results) +select + * +from ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rank() over (partition by i_category order by sumsales desc) rk + from results_rollup) dw2 +where + rk <= 100 +order by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rk +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql new file mode 100644 index 000000000000..4aec9c7fd1fd --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ss_net_profit) as total_sum, + s_state ,s_county, + 0 as gstate, + 0 as g_county + from + store_sales, date_dim d1, store + where + d1.d_month_seq between 1212 and 1212 + 11 + and d1.d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + and s_state in ( + select s_state + from ( + select + s_state as s_state, + rank() over (partition by s_state order by sum(ss_net_profit) desc) as ranking + from store_sales, store, date_dim + where d_month_seq between 1212 and 1212 + 11 + and d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + group by s_state) tmp1 + where ranking <= 5) + group by + s_state, s_county), +results_rollup as ( + select + total_sum, + s_state, + s_county, + 0 as g_state, + 0 as g_county, + 0 as lochierarchy + from results + union + select + sum(total_sum) as total_sum,s_state, + NULL as s_county, + 0 as g_state, + 1 as g_county, + 1 as lochierarchy + from results + group by s_state + union + select + sum(total_sum) as total_sum, + NULL as s_state, + NULL as s_county, + 1 as g_state, + 1 as g_county, + 2 as lochierarchy + from results) +select + total_sum, + s_state, + s_county, + lochierarchy, + rank() over ( + partition by lochierarchy, + case when g_county = 0 then s_state end + order by total_sum desc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then s_state end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql new file mode 100755 index 000000000000..066d6a587e91 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql @@ -0,0 +1,40 @@ +SELECT + i_item_desc, + w_warehouse_name, + d1.d_week_seq, + count(CASE WHEN p_promo_sk IS NULL + THEN 1 + ELSE 0 END) no_promo, + count(CASE WHEN p_promo_sk IS NOT NULL + THEN 1 + ELSE 0 END) promo, + count(*) total_cnt +FROM catalog_sales + JOIN inventory ON (cs_item_sk = inv_item_sk) + JOIN warehouse ON (w_warehouse_sk = inv_warehouse_sk) + JOIN item ON (i_item_sk = cs_item_sk) + JOIN customer_demographics ON (cs_bill_cdemo_sk = cd_demo_sk) + JOIN household_demographics ON (cs_bill_hdemo_sk = hd_demo_sk) + JOIN date_dim d1 ON (cs_sold_date_sk = d1.d_date_sk) + JOIN date_dim d2 ON (inv_date_sk = d2.d_date_sk) + JOIN date_dim d3 ON (cs_ship_date_sk = d3.d_date_sk) + LEFT OUTER JOIN promotion ON (cs_promo_sk = p_promo_sk) + 
LEFT OUTER JOIN catalog_returns ON (cr_item_sk = cs_item_sk AND cr_order_number = cs_order_number) +-- q72 in TPCDS v1.4 had conditions below: +-- WHERE d1.d_week_seq = d2.d_week_seq +-- AND inv_quantity_on_hand < cs_quantity +-- AND d3.d_date > (cast(d1.d_date AS DATE) + interval 5 days) +-- AND hd_buy_potential = '>10000' +-- AND d1.d_year = 1999 +-- AND hd_buy_potential = '>10000' +-- AND cd_marital_status = 'D' +-- AND d1.d_year = 1999 +WHERE d1.d_week_seq = d2.d_week_seq + AND inv_quantity_on_hand < cs_quantity + AND d3.d_date > d1.d_date + INTERVAL 5 days + AND hd_buy_potential = '1001-5000' + AND d1.d_year = 2001 + AND cd_marital_status = 'M' +GROUP BY i_item_desc, w_warehouse_name, d1.d_week_seq +ORDER BY total_cnt DESC, i_item_desc, w_warehouse_name, d_week_seq +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql new file mode 100755 index 000000000000..94a0063b36c0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql @@ -0,0 +1,60 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ss_net_paid) year_total, + 's' sale_type + FROM + customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ws_net_paid) year_total, + 'w' sale_type + FROM + customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year) +SELECT + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name +FROM + year_total t_s_firstyear, year_total t_s_secyear, + year_total t_w_firstyear, year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.year = 2001 + AND t_s_secyear.year = 2001 + 1 + AND t_w_firstyear.year = 2001 + AND t_w_secyear.year = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + ELSE NULL END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + ELSE NULL END +-- order-by list of q74 in TPCDS v1.4 is below: +-- ORDER BY 1, 1, 1 +ORDER BY 2, 1, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql new file mode 100755 index 000000000000..ae5dc97ef231 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql @@ -0,0 +1,78 @@ +WITH all_sales AS ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + SUM(sales_cnt) AS sales_cnt, + SUM(sales_amt) AS sales_amt + FROM ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + cs_quantity - COALESCE(cr_return_quantity, 0) AS sales_cnt, + cs_ext_sales_price - COALESCE(cr_return_amount, 0.0) AS sales_amt + FROM 
catalog_sales + JOIN item ON i_item_sk = cs_item_sk + JOIN date_dim ON d_date_sk = cs_sold_date_sk + LEFT JOIN catalog_returns ON (cs_order_number = cr_order_number + AND cs_item_sk = cr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ss_quantity - COALESCE(sr_return_quantity, 0) AS sales_cnt, + ss_ext_sales_price - COALESCE(sr_return_amt, 0.0) AS sales_amt + FROM store_sales + JOIN item ON i_item_sk = ss_item_sk + JOIN date_dim ON d_date_sk = ss_sold_date_sk + LEFT JOIN store_returns ON (ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ws_quantity - COALESCE(wr_return_quantity, 0) AS sales_cnt, + ws_ext_sales_price - COALESCE(wr_return_amt, 0.0) AS sales_amt + FROM web_sales + JOIN item ON i_item_sk = ws_item_sk + JOIN date_dim ON d_date_sk = ws_sold_date_sk + LEFT JOIN web_returns ON (ws_order_number = wr_order_number + AND ws_item_sk = wr_item_sk) + WHERE i_category = 'Books') sales_detail + GROUP BY d_year, i_brand_id, i_class_id, i_category_id, i_manufact_id) +SELECT + prev_yr.d_year AS prev_year, + curr_yr.d_year AS year, + curr_yr.i_brand_id, + curr_yr.i_class_id, + curr_yr.i_category_id, + curr_yr.i_manufact_id, + prev_yr.sales_cnt AS prev_yr_cnt, + curr_yr.sales_cnt AS curr_yr_cnt, + curr_yr.sales_cnt - prev_yr.sales_cnt AS sales_cnt_diff, + curr_yr.sales_amt - prev_yr.sales_amt AS sales_amt_diff +FROM all_sales curr_yr, all_sales prev_yr +WHERE curr_yr.i_brand_id = prev_yr.i_brand_id + AND curr_yr.i_class_id = prev_yr.i_class_id + AND curr_yr.i_category_id = prev_yr.i_category_id + AND curr_yr.i_manufact_id = prev_yr.i_manufact_id + AND curr_yr.d_year = 2002 + AND prev_yr.d_year = 2002 - 1 + AND CAST(curr_yr.sales_cnt AS DECIMAL(17, 2)) / CAST(prev_yr.sales_cnt AS DECIMAL(17, 2)) < 0.9 +ORDER BY + sales_cnt_diff, + sales_amt_diff -- This order-by condition did not exist in TPCDS v1.4 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql new file mode 100644 index 000000000000..fc69c43470f1 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql @@ -0,0 +1,121 @@ +-- This is a new query in TPCDS v2.7 +with ss as ( + select + s_store_sk, + sum(ss_ext_sales_price) as sales, + sum(ss_net_profit) as profit + from + store_sales, date_dim, store + where + ss_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ss_store_sk = s_store_sk + group by + s_store_sk), +sr as ( + select + s_store_sk, + sum(sr_return_amt) as returns, + sum(sr_net_loss) as profit_loss + from + store_returns, date_dim, store + where + sr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and sr_store_sk = s_store_sk + group by + s_store_sk), +cs as ( + select + cs_call_center_sk, + sum(cs_ext_sales_price) as sales, + sum(cs_net_profit) as profit + from + catalog_sales, + date_dim + where + cs_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + group by + cs_call_center_sk), + cr as ( + select + sum(cr_return_amount) as returns, + sum(cr_net_loss) as profit_loss + from catalog_returns, + date_dim + where + cr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as 
date) + and (cast('1998-08-04' as date) + interval 30 days)), +ws as ( select wp_web_page_sk, + sum(ws_ext_sales_price) as sales, + sum(ws_net_profit) as profit + from web_sales, + date_dim, + web_page + where ws_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ws_web_page_sk = wp_web_page_sk + group by wp_web_page_sk), + wr as + (select wp_web_page_sk, + sum(wr_return_amt) as returns, + sum(wr_net_loss) as profit_loss + from web_returns, + date_dim, + web_page + where wr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and wr_web_page_sk = wp_web_page_sk + group by wp_web_page_sk) + , + results as + (select channel + , id + , sum(sales) as sales + , sum(returns) as returns + , sum(profit) as profit + from + (select 'store channel' as channel + , ss.s_store_sk as id + , sales + , coalesce(returns, 0) as returns + , (profit - coalesce(profit_loss,0)) as profit + from ss left join sr + on ss.s_store_sk = sr.s_store_sk + union all + select 'catalog channel' as channel + , cs_call_center_sk as id + , sales + , returns + , (profit - profit_loss) as profit + from cs + , cr + union all + select 'web channel' as channel + , ws.wp_web_page_sk as id + , sales + , coalesce(returns, 0) returns + , (profit - coalesce(profit_loss,0)) as profit + from ws left join wr + on ws.wp_web_page_sk = wr.wp_web_page_sk + ) x + group by channel, id ) + + select * + from ( + select channel, id, sales, returns, profit from results + union + select channel, NULL AS id, sum(sales) as sales, sum(returns) as returns, sum(profit) as profit from results group by channel + union + select NULL AS channel, NULL AS id, sum(sales) as sales, sum(returns) as returns, sum(profit) as profit from results +) foo +order by + channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql new file mode 100755 index 000000000000..d03d8af77174 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql @@ -0,0 +1,75 @@ +WITH ws AS +(SELECT + d_year AS ws_sold_year, + ws_item_sk, + ws_bill_customer_sk ws_customer_sk, + sum(ws_quantity) ws_qty, + sum(ws_wholesale_cost) ws_wc, + sum(ws_sales_price) ws_sp + FROM web_sales + LEFT JOIN web_returns ON wr_order_number = ws_order_number AND ws_item_sk = wr_item_sk + JOIN date_dim ON ws_sold_date_sk = d_date_sk + WHERE wr_order_number IS NULL + GROUP BY d_year, ws_item_sk, ws_bill_customer_sk +), + cs AS + (SELECT + d_year AS cs_sold_year, + cs_item_sk, + cs_bill_customer_sk cs_customer_sk, + sum(cs_quantity) cs_qty, + sum(cs_wholesale_cost) cs_wc, + sum(cs_sales_price) cs_sp + FROM catalog_sales + LEFT JOIN catalog_returns ON cr_order_number = cs_order_number AND cs_item_sk = cr_item_sk + JOIN date_dim ON cs_sold_date_sk = d_date_sk + WHERE cr_order_number IS NULL + GROUP BY d_year, cs_item_sk, cs_bill_customer_sk + ), + ss AS + (SELECT + d_year AS ss_sold_year, + ss_item_sk, + ss_customer_sk, + sum(ss_quantity) ss_qty, + sum(ss_wholesale_cost) ss_wc, + sum(ss_sales_price) ss_sp + FROM store_sales + LEFT JOIN store_returns ON sr_ticket_number = ss_ticket_number AND ss_item_sk = sr_item_sk + JOIN date_dim ON ss_sold_date_sk = d_date_sk + WHERE sr_ticket_number IS NULL + GROUP BY d_year, ss_item_sk, ss_customer_sk + ) +SELECT + round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) ratio, + ss_qty store_qty, + ss_wc store_wholesale_cost, + ss_sp 
store_sales_price, + coalesce(ws_qty, 0) + coalesce(cs_qty, 0) other_chan_qty, + coalesce(ws_wc, 0) + coalesce(cs_wc, 0) other_chan_wholesale_cost, + coalesce(ws_sp, 0) + coalesce(cs_sp, 0) other_chan_sales_price +FROM ss + LEFT JOIN ws + ON (ws_sold_year = ss_sold_year AND ws_item_sk = ss_item_sk AND ws_customer_sk = ss_customer_sk) + LEFT JOIN cs + ON (cs_sold_year = ss_sold_year AND cs_item_sk = ss_item_sk AND cs_customer_sk = ss_customer_sk) +WHERE coalesce(ws_qty, 0) > 0 AND coalesce(cs_qty, 0) > 0 AND ss_sold_year = 2000 +ORDER BY + -- order-by list of q78 in TPCDS v1.4 is below: + -- ratio, + -- ss_qty DESC, ss_wc DESC, ss_sp DESC, + -- other_chan_qty, + -- other_chan_wholesale_cost, + -- other_chan_sales_price, + -- round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) + ss_sold_year, + ss_item_sk, + ss_customer_sk, + ss_qty desc, + ss_wc desc, + ss_sp desc, + other_chan_qty, + other_chan_wholesale_cost, + other_chan_sales_price, + ratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql new file mode 100644 index 000000000000..686e03ba2a6d --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql @@ -0,0 +1,147 @@ +-- This is a new query in TPCDS v2.7 +with ssr as ( + select + s_store_id as store_id, + sum(ss_ext_sales_price) as sales, + sum(coalesce(sr_return_amt, 0)) as returns, + sum(ss_net_profit - coalesce(sr_net_loss, 0)) as profit + from + store_sales left outer join store_returns on ( + ss_item_sk = sr_item_sk and ss_ticket_number = sr_ticket_number), + date_dim, + store, + item, + promotion + where + ss_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ss_store_sk = s_store_sk + and ss_item_sk = i_item_sk + and i_current_price > 50 + and ss_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + s_store_id), +csr as ( + select + cp_catalog_page_id as catalog_page_id, + sum(cs_ext_sales_price) as sales, + sum(coalesce(cr_return_amount, 0)) as returns, + sum(cs_net_profit - coalesce(cr_net_loss, 0)) as profit + from + catalog_sales left outer join catalog_returns on + (cs_item_sk = cr_item_sk and cs_order_number = cr_order_number), + date_dim, + catalog_page, + item, + promotion + where + cs_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and cs_catalog_page_sk = cp_catalog_page_sk + and cs_item_sk = i_item_sk + and i_current_price > 50 + and cs_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + cp_catalog_page_id), +wsr as ( + select + web_site_id, + sum(ws_ext_sales_price) as sales, + sum(coalesce(wr_return_amt, 0)) as returns, + sum(ws_net_profit - coalesce(wr_net_loss, 0)) as profit + from + web_sales left outer join web_returns on ( + ws_item_sk = wr_item_sk and ws_order_number = wr_order_number), + date_dim, + web_site, + item, + promotion + where + ws_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ws_web_site_sk = web_site_sk + and ws_item_sk = i_item_sk + and i_current_price > 50 + and ws_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + web_site_id), +results as ( + select + channel, + id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from ( + select + 'store channel' as channel, + 'store' || store_id as id, + sales, + returns, + profit + from + ssr + union all + select + 'catalog 
channel' as channel, + 'catalog_page' || catalog_page_id as id, + sales, + returns, + profit + from + csr + union all + select + 'web channel' as channel, + 'web_site' || web_site_id as id, + sales, + returns, + profit + from + wsr) x + group by + channel, id) +select + channel, + id, + sales, + returns, + profit +from ( + select + channel, + id, + sales, + returns, + profit + from + results + union + select + channel, + NULL AS id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from + results + group by + channel + union + select + NULL AS channel, + NULL AS id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from + results) foo +order by + channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql new file mode 100644 index 000000000000..fff76b08d4ba --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql @@ -0,0 +1,61 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ws_net_paid) as total_sum, + i_category, i_class, + 0 as g_category, + 0 as g_class + from + web_sales, date_dim d1, item + where + d1.d_month_seq between 1212 and 1212 + 11 + and d1.d_date_sk = ws_sold_date_sk + and i_item_sk = ws_item_sk + group by + i_category, i_class), +results_rollup as( + select + total_sum, + i_category, + i_class, + g_category, + g_class, + 0 as lochierarchy + from + results + union + select + sum(total_sum) as total_sum, + i_category, + NULL as i_class, + 0 as g_category, + 1 as g_class, + 1 as lochierarchy + from + results + group by + i_category + union + select + sum(total_sum) as total_sum, + NULL as i_category, + NULL as i_class, + 1 as g_category, + 1 as g_class, + 2 as lochierarchy + from + results) +select + total_sum, + i_category ,i_class, lochierarchy, + rank() over ( + partition by lochierarchy, + case when g_class = 0 then i_category end + order by total_sum desc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then i_category end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql new file mode 100755 index 000000000000..771117add2ed --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql @@ -0,0 +1,22 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ss_ext_sales_price) AS itemrevenue, + sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + store_sales, item, date_dim +WHERE + ss_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ss_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 137c5bea2abb..d635912cf720 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import 
org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -279,4 +280,16 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(query, expected) } } + + test("SPARK-24013: unneeded compress can cause performance issues with sorted input") { + val buffer = new PercentileDigest(1.0D / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY) + var compressCounts = 0 + (1 to 10000000).foreach { i => + buffer.add(i) + if (buffer.isCompressed) compressCounts += 1 + } + assert(compressCounts > 0) + buffer.quantileSummaries + assert(buffer.isCompressed) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala index 7037749f1447..e51aad021fcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala @@ -46,7 +46,7 @@ abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with B override def beforeAll() { super.beforeAll() - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } protected def checkGeneratedCode(plan: SparkPlan): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1e52445f28fc..81b7e18773f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -21,15 +21,15 @@ import scala.collection.mutable.HashSet import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.concurrent.Eventually._ - import org.apache.spark.CleanerListener +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{AccumulatorContext, Utils} @@ -53,7 +53,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head @@ -79,7 +79,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { plan.collect { case InMemoryTableScanExec(_, _, relation) => - getNumInMemoryTablesRecursively(relation.child) + 1 + getNumInMemoryTablesRecursively(relation.cachedPlan) + 1 }.sum } @@ -201,7 +201,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("testData") assertResult(0, 
"Double InMemoryRelations found, cacheTable() is not idempotent") { spark.table("testData").queryExecution.withCachedData.collect { - case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r + case r: InMemoryRelation if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] => r }.size } @@ -368,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.batchStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.batchStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 @@ -782,4 +782,42 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(getNumInMemoryRelations(cachedDs2) == 1) } } + + test("SPARK-23312: vectorized cache reader can be disabled") { + Seq(true, false).foreach { vectorized => + withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { + val df = spark.range(10).cache() + df.queryExecution.executedPlan.foreach { + case i: InMemoryTableScanExec => + assert(i.supportsBatch == vectorized && i.supportCodegen == vectorized) + case _ => + } + } + } + } + + private def checkIfNoJobTriggered[T](f: => T): T = { + var numJobTrigered = 0 + val jobListener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + numJobTrigered += 1 + } + } + sparkContext.addSparkListener(jobListener) + try { + val result = f + sparkContext.listenerBus.waitUntilEmpty(10000L) + assert(numJobTrigered === 0) + result + } finally { + sparkContext.removeSparkListener(jobListener) + } + } + + test("SPARK-23880 table cache should be lazy and don't trigger any jobs") { + val cachedData = checkIfNoJobTriggered { + spark.range(1002).filter('id > 1000).orderBy('id.desc).cache() + } + assert(cachedData.collect === Seq(1001)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala new file mode 100644 index 000000000000..b74fe2f90df2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.test.SharedSQLContext + +class ComplexTypesSuite extends QueryTest with SharedSQLContext { + + override def beforeAll() { + super.beforeAll() + spark.range(10).selectExpr( + "id + 1 as i1", "id + 2 as i2", "id + 3 as i3", "id + 4 as i4", "id + 5 as i5") + .write.saveAsTable("tab") + } + + override def afterAll() { + try { + spark.sql("DROP TABLE IF EXISTS tab") + } finally { + super.afterAll() + } + } + + def checkNamedStruct(plan: LogicalPlan, expectedCount: Int): Unit = { + var count = 0 + plan.foreach { operator => + operator.transformExpressions { + case c: CreateNamedStruct => + count += 1 + c + } + } + + if (expectedCount != count) { + fail(s"expect $expectedCount CreateNamedStruct but got $count.") + } + } + + test("simple case") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4) as col2") + .filter("col2.c > 11").selectExpr("col1.a") + checkAnswer(df, Row(9) :: Row(10) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("named_struct is used in the top Project") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .selectExpr("col1.a", "col1") + .filter("col1.a > 8") + checkAnswer(df, Row(9, Row(9, 10)) :: Row(10, Row(10, 11)) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 1) + + val df1 = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .sort("col1") + .selectExpr("col1.a") + .filter("col1.a > 8") + checkAnswer(df1, Row(9) :: Row(10) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 1) + } + + test("expression in named_struct") { + val df = spark.table("tab") + .selectExpr("i5", "struct(i1 as exp, i2, i3) as cola") + .selectExpr("cola.exp", "cola.i3").filter("cola.i3 > 10") + checkAnswer(df, Row(9, 11) :: Row(10, 12) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("i5", "struct(i1 + 1 as exp, i2, i3) as cola") + .selectExpr("cola.i3").filter("cola.exp > 10") + checkAnswer(df1, Row(12) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("nested case") { + val df = spark.table("tab") + .selectExpr("struct(struct(i2, i3) as exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.exp.i2 > 10") + checkAnswer(df, Row(11, 13) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("struct(i2, i3) as exp", "i4") + .selectExpr("struct(exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.i4 > 11") + checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index cee85ec8af04..949505e449fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def 
computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id) + val data = spark.range(0, n, 1, 1).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 06848e4d2b29..e7776e36702a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import scala.util.Random +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -27,7 +29,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.DecimalData -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.DecimalType case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) @@ -456,7 +458,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("null moments") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), Row(null, null, null, null, null)) @@ -666,4 +667,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(exchangePlans.length == 1) } } + + Seq(true, false).foreach { codegen => + test("SPARK-22951: dropDuplicates on empty dataFrames should produce correct aggregate " + + s"results when codegen is enabled: $codegen") { + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegen.toString)) { + // explicit global aggregations + val emptyAgg = Map.empty[String, String] + checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0))) + checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0))) + + // global aggregation is converted to grouping aggregation: + assert(spark.emptyDataFrame.dropDuplicates().count() == 0) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f45..ecce06f4c075 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -276,7 +276,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("sort_array function") { + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), (Array.empty[Int], Array.empty[String]), @@ 
-286,28 +286,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(sort_array($"a"), sort_array($"b")), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.select(sort_array($"a", false), sort_array($"b", false)), Seq( Row(Seq(3, 2, 1), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a)", "sort_array(b)"), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a, true)", "sort_array(b, false)"), Seq( Row(Seq(1, 2, 3), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) @@ -324,6 +324,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(intercept[AnalysisException] { df3.selectExpr("sort_array(a)").collect() }.getMessage().contains("only supports array input")) + + checkAnswer( + df.select(array_sort($"a"), array_sort($"b")), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_sort(a)", "array_sort(b)"), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + + checkAnswer( + df2.selectExpr("array_sort(a)"), + Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null))) + ) + + assert(intercept[AnalysisException] { + df3.selectExpr("array_sort(a)").collect() + }.getMessage().contains("only supports array input")) } test("array size function") { @@ -341,6 +365,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(-1)) ) + + checkAnswer( + df.selectExpr("cardinality(a)"), + Seq(Row(2L), Row(0L), Row(3L), Row(-1L)) + ) } test("map size function") { @@ -413,6 +442,407 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("slice function") { + val df = Seq( + Seq(1, 2, 3), + Seq(4, 5) + ).toDF("x") + + val answer = Seq(Row(Seq(2, 3)), Row(Seq(5))) + + checkAnswer(df.select(slice(df("x"), 2, 2)), answer) + checkAnswer(df.selectExpr("slice(x, 2, 2)"), answer) + + val answerNegative = Seq(Row(Seq(3)), Row(Seq(5))) + checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative) + checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative) + } + + test("array_join function") { + val df = Seq( + (Seq[String]("a", "b"), ","), + (Seq[String]("a", null, "b"), ","), + (Seq.empty[String], ",") + ).toDF("x", "delimiter") + + checkAnswer( + df.select(array_join(df("x"), ";")), + Seq(Row("a;b"), Row("a;b"), Row("")) + ) + checkAnswer( + df.select(array_join(df("x"), ";", "NULL")), + Seq(Row("a;b"), Row("a;NULL;b"), Row("")) + ) + checkAnswer( + df.selectExpr("array_join(x, delimiter)"), + Seq(Row("a,b"), Row("a,b"), Row(""))) + checkAnswer( + df.selectExpr("array_join(x, delimiter, 'NULL')"), + Seq(Row("a,b"), Row("a,NULL,b"), Row(""))) + } + + test("array_min function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(1), Row(null), Row(null), Row(-100)) + + checkAnswer(df.select(array_min(df("a"))), answer) + 
checkAnswer(df.selectExpr("array_min(a)"), answer) + } + + test("array_max function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(3), Row(null), Row(null), Row(1)) + + checkAnswer(df.select(array_max(df("a"))), answer) + checkAnswer(df.selectExpr("array_max(a)"), answer) + } + + test("reverse function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on + + // String test cases + val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") + + checkAnswer( + oneRowDF.select(reverse('s)), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(s)"), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.select(reverse('i)), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(i)"), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(null)"), + Seq(Row(null)) + ) + + // Array test cases (primitive-type elements) + val idf = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + checkAnswer( + idf.select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.filter(dummyFilter('i)).select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("reverse(i)"), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + + // Array test cases (non-primitive-type elements) + val sdf = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + checkAnswer( + sdf.select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.filter(dummyFilter('s)).select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(s)"), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(struct(1, 'a'))") + } + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(map(1, 'a'))") + } + } + + test("array position function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + checkAnswer( + df.select(array_position(df("a"), 1)), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.selectExpr("array_position(a, 1)"), + Seq(Row(1L), Row(0L)) + ) + + checkAnswer( + df.select(array_position(df("a"), null)), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("array_position(a, null)"), + Seq(Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("array_position(array(array(1), null)[0], 1)"), + Seq(Row(1L), Row(1L)) + ) + checkAnswer( + df.selectExpr("array_position(array(1, null), array(1, null)[0])"), + Seq(Row(1L), Row(1L)) + ) + } 
+ + test("element_at function") { + val df = Seq( + (Seq[String]("1", "2", "3")), + (Seq[String](null, "")), + (Seq[String]()) + ).toDF("a") + + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 0)), + Seq(Row(null), Row(null), Row(null)) + ) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 1.1)), + Seq(Row(null), Row(null), Row(null)) + ) + } + checkAnswer( + df.select(element_at(df("a"), 4)), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.select(element_at(df("a"), 1)), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), -1)), + Seq(Row("3"), Row(""), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 4)"), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 1)"), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, -1)"), + Seq(Row("3"), Row(""), Row(null)) + ) + } + + test("concat function - arrays") { + val nseqi : Seq[Int] = null + val nseqs : Seq[String] = null + val df = Seq( + + (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), + (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) + ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") + + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on + + // Simple test cases + checkAnswer( + df.selectExpr("array(1, 2, 3L)"), + Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L))) + ) + + checkAnswer ( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) + checkAnswer( + df.select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.selectExpr("concat(array(1, null), i2, i3)"), + Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) + ) + checkAnswer( + df.select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.selectExpr("concat(s1, s2, s3)"), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + + // Null test cases + checkAnswer( + df.select(concat($"i1", $"in")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"in", $"i1")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"s1", $"sn")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"sn", $"s1")), + Seq(Row(null), Row(null)) + ) + + // Type error test cases + intercept[AnalysisException] { + df.selectExpr("concat(i1, i2, null)") + } + + intercept[AnalysisException] { + df.selectExpr("concat(i1, array(i1, i2))") + } + + val e = intercept[AnalysisException] { + df.selectExpr("concat(map(1, 2), map(3, 4))") + } + assert(e.getMessage.contains("string, binary or array")) + } + + test("flatten function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, 
Seq(1))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq(1), null)), + (Seq(null, Seq(1))), + (Seq(null, null)) + ).toDF("i") + + val intDFResult = Seq( + Row(Seq(1, 2, 3, 4, 5, 6)), + Row(Seq(1, 2)), + Row(Seq(1)), + Row(Seq(1)), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.filter(dummyFilter($"i")).select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq("a"), null)), + (Seq(null, Seq("a"))), + (Seq(null, null)) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a")), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.filter(dummyFilter($"s")).select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.select(flatten($"arr")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"i")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"s")) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("flatten(null)") + } + } + + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { + import DataFrameFunctionsSuite.CodegenFallbackExpr + for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index aef0d7f3e425..0d9eeabb397a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class DataFrameJoinSuite extends QueryTest with SharedSQLContext { @@ -274,4 +275,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { checkAnswer(innerJoin, Row(1) :: Nil) } + test("SPARK-23087: don't throw Analysis Exception in CheckCartesianProduct when join condition " + + "is false or null") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val df = spark.range(10) + val dfNull = spark.range(10).select(lit(null).as("b")) + df.join(dfNull, $"id" === $"b", "left").queryExecution.optimizedPlan + + val dfOne = df.select(lit(1).as("a")) + val dfTwo = spark.range(10).select(lit(2).as("b")) + dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan + } + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 45afbd29d190..b0b46640ff31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -23,8 +23,8 @@ import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -153,23 +153,17 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall test("Cancelling stage in a query with Range.") { val listener = new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds)) { - assert(DataFrameRangeSuite.stageToKill > 0) - } - sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sparkContext.cancelStage(taskStart.stageId) } } sparkContext.addSparkListener(listener) for (codegen <- Seq(true, false)) { withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - DataFrameRangeSuite.stageToKill = -1 val ex = intercept[SparkException] { - spark.range(1000000000L).map { x => - DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() - x - }.toDF("id").agg(sum("id")).collect() + spark.range(0, 100000000000L, 1, 1) + .toDF("id").agg(sum("id")).collect() } ex.getCause() match { case null => @@ -180,10 +174,13 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } } + // Wait until all ListenerBus events consumed to make sure cancelStage called for all stages + sparkContext.listenerBus.waitUntilEmpty(20.seconds.toMillis) eventually(timeout(20.seconds)) { assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } } + sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { @@ -203,7 +200,3 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } - -object DataFrameRangeSuite { - @volatile var stageToKill = -1 -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b21c3b64a2..8eae35325fae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -154,24 +154,24 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val Array(d1, d2) = df.stat.approxQuantile("doubles", Array(q1, q2), epsilon) val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon) - val error_single = 2 * 1000 * epsilon - val error_double = 2 * 2000 * epsilon + val errorSingle = 1000 * epsilon + val errorDouble = 2.0 * errorSingle - assert(math.abs(single1 - q1 * n) < error_single) - assert(math.abs(double2 - 2 * q2 * n) < error_double) - assert(math.abs(s1 - q1 * n) < error_single) - assert(math.abs(s2 - q2 * n) < error_single) - assert(math.abs(d1 - 2 
* q1 * n) < error_double) - assert(math.abs(d2 - 2 * q2 * n) < error_double) + assert(math.abs(single1 - q1 * n) <= errorSingle) + assert(math.abs(double2 - 2 * q2 * n) <= errorDouble) + assert(math.abs(s1 - q1 * n) <= errorSingle) + assert(math.abs(s2 - q2 * n) <= errorSingle) + assert(math.abs(d1 - 2 * q1 * n) <= errorDouble) + assert(math.abs(d2 - 2 * q2 * n) <= errorDouble) // Multiple columns val Array(Array(ms1, ms2), Array(md1, md2)) = df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon) - assert(math.abs(ms1 - q1 * n) < error_single) - assert(math.abs(ms2 - q2 * n) < error_single) - assert(math.abs(md1 - 2 * q1 * n) < error_double) - assert(math.abs(md2 - 2 * q2 * n) < error_double) + assert(math.abs(ms1 - q1 * n) <= errorSingle) + assert(math.abs(ms2 - q2 * n) <= errorSingle) + assert(math.abs(md1 - 2 * q1 * n) <= errorDouble) + assert(math.abs(md2 - 2 * q2 * n) <= errorDouble) } // quantile should be in the range [0.0, 1.0] @@ -260,6 +260,14 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(res2(1).isEmpty) } + // SPARK-22957: check for 32bit overflow when computing rank. + // ignored - takes 4 minutes to run. + ignore("approx quantile 4: test for Int overflow") { + val res = spark.range(3000000000L).stat.approxQuantile("id", Array(0.8, 0.9), 0.05) + assert(res(0) > 2200000000.0) + assert(res(1) > 2200000000.0) + } + test("crosstab") { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { val rng = new Random() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5e4c1a6a484f..60e84e6ee750 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -589,6 +590,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Nil) } + test("SPARK-23274: except between two projects without references used in filter") { + val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") + val df1 = df.filter($"a" === 1) + val df2 = df.filter($"a" === 2) + checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) + checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) + } + test("except distinct - SQL compliance") { val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") val df_right = Seq(1, 3).toDF("id") @@ -1255,6 +1264,34 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer) } + test("SPARK-23023 Cast rows to strings in showString") { + val df1 = Seq(Seq(1, 2, 3, 4)).toDF("a") + assert(df1.showString(10) === + s"""+------------+ + || a| + |+------------+ + ||[1, 2, 3, 4]| + |+------------+ + |""".stripMargin) + val df2 = Seq(Map(1 -> "a", 2 -> "b")).toDF("a") + assert(df2.showString(10) === + s"""+----------------+ + || a| + |+----------------+ + ||[1 -> a, 2 -> b]| + |+----------------+ + |""".stripMargin) + val df3 = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b") + 
assert(df3.showString(10) === + s"""+------+---+ + || a| b| + |+------+---+ + ||[1, a]| 0| + ||[2, b]| 0| + |+------+---+ + |""".stripMargin) + } + test("SPARK-7327 show with empty dataFrame") { val expectedAnswer = """+---+-----+ ||key|value| @@ -2018,11 +2055,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { expr: String, expectedNonNullableColumns: Seq[String]): Unit = { val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr) - // In the logical plan, all the output columns of input dataframe are nullable - dfWithFilter.queryExecution.optimizedPlan.collect { - case e: Filter => assert(e.output.forall(_.nullable)) - } - dfWithFilter.queryExecution.executedPlan.collect { // When the child expression in isnotnull is null-intolerant (i.e. any null input will // result in null output), the involved columns are converted to not nullable; @@ -2228,4 +2260,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(0, 10) :: Nil) assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + + test("Uuid expressions should produce same results at retries in the same DataFrame") { + val df = spark.range(1).select($"id", new Column(Uuid())) + checkAnswer(df, df.collect()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala new file mode 100644 index 000000000000..2a0b2b85e10a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -0,0 +1,419 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Window frame testing for DataFrame API. 
+ */ +class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("lead/lag with empty data frame") { + val df = Seq.empty[(Int, String)].toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + lead("value", 1).over(window), + lag("value", 1).over(window)), + Nil) + } + + test("lead/lag with positive offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", 1).over(window), + lag("value", 1).over(window)), + Row(1, "3", null) :: Row(1, null, "1") :: Row(2, "4", null) :: Row(2, null, "2") :: Nil) + } + + test("reverse lead/lag with positive offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value".desc) + + checkAnswer( + df.select( + $"key", + lead("value", 1).over(window), + lag("value", 1).over(window)), + Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, null, "4") :: Nil) + } + + test("lead/lag with negative offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", -1).over(window), + lag("value", -1).over(window)), + Row(1, null, "3") :: Row(1, "1", null) :: Row(2, null, "4") :: Row(2, "2", null) :: Nil) + } + + test("reverse lead/lag with negative offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value".desc) + + checkAnswer( + df.select( + $"key", + lead("value", -1).over(window), + lag("value", -1).over(window)), + Row(1, null, "1") :: Row(1, "3", null) :: Row(2, null, "2") :: Row(2, "4", null) :: Nil) + } + + test("lead/lag with default value") { + val default = "n/a" + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4"), (2, "5")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", 2, default).over(window), + lag("value", 2, default).over(window), + lead("value", -2, default).over(window), + lag("value", -2, default).over(window)), + Row(1, default, default, default, default) :: Row(1, default, default, default, default) :: + Row(2, "5", default, default, "5") :: Row(2, default, "2", "2", default) :: + Row(2, default, default, default, default) :: Nil) + } + + test("rows/range between with empty data frame") { + val df = Seq.empty[(String, Int)].toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + 'key, + first("value").over( + window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + first("value").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Nil) + } + + test("rows between should accept int/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), + Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + + val e = intercept[AnalysisException]( + df.select( + $"key", + count("key").over( + 
Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) + assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) + } + + test("range between should accept at most one ORDER BY expression when unbounded") { + val df = Seq((1, 1)).toDF("key", "value") + val window = Window.orderBy($"key", $"value") + + checkAnswer( + df.select( + $"key", + min("key").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Seq(Row(1, 1)) + ) + + val e1 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(Window.unboundedPreceding, 1)))) + assert(e1.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + + val e2 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(-1, Window.unboundedFollowing)))) + assert(e2.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + + val e3 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(-1, 1)))) + assert(e3.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + } + + test("range between should accept numeric values only when bounded") { + val df = Seq("non_numeric").toDF("value") + val window = Window.orderBy($"value") + + checkAnswer( + df.select( + $"value", + min("value").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("non_numeric", "non_numeric") :: Nil) + + val e1 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(Window.unboundedPreceding, 1)))) + assert(e1.message.contains("The data type of the upper bound 'string' " + + "does not match the expected data type")) + + val e2 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(-1, Window.unboundedFollowing)))) + assert(e2.message.contains("The data type of the lower bound 'string' " + + "does not match the expected data type")) + + val e3 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(-1, 1)))) + assert(e3.message.contains("The data type of the lower bound 'string' " + + "does not match the expected data type")) + } + + test("range between should accept int/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), + Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), + Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) + ) + + def dt(date: String): Date = Date.valueOf(date) + + val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), + (dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2")) + .toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)) + + checkAnswer( + df2.select( + $"key", + count("key").over(window)), + Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), 
Row(dt("2020-12-31"), 1), + Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) + ) + } + + test("range between should accept double values as boundary") { + val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), (3.3D, "2"), (2.02D, "1"), + (100.001D, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(currentRow, lit(2.5D)) + + checkAnswer( + df.select( + $"key", + count("key").over(window)), + Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) + ) + } + + test("range between should accept interval values as boundary") { + def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) + + val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), + (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) + .toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key") + .rangeBetween(currentRow, lit(CalendarInterval.fromString("interval 23 days 4 hours"))) + + checkAnswer( + df.select( + $"key", + count("key").over(window)), + Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), + Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) + ) + } + + test("unbounded rows/range between with aggregation") { + val df = Seq(("one", 1), ("two", 2), ("one", 3), ("two", 4)).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + 'key, + sum("value").over(window. + rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + sum("value").over(window. + rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("one", 4, 4) :: Row("one", 4, 4) :: Row("two", 6, 6) :: Row("two", 6, 6) :: Nil) + } + + test("unbounded preceding/following rows between with aggregation") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key") + + checkAnswer( + df.select( + $"key", + last("key").over( + window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), + last("key").over( + window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), + Row(1, 1, 1) :: Row(2, 3, 2) :: Row(3, 3, 3) :: Row(1, 4, 1) :: Row(2, 4, 2) :: + Row(4, 4, 4) :: Nil) + } + + test("reverse unbounded preceding/following rows between with aggregation") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key".desc) + + checkAnswer( + df.select( + $"key", + last("key").over( + window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), + last("key").over( + window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), + Row(1, 1, 1) :: Row(3, 2, 3) :: Row(2, 2, 2) :: Row(4, 1, 4) :: Row(2, 1, 2) :: + Row(1, 1, 1) :: Nil) + } + + test("unbounded preceding/following range between with aggregation") { + val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy("value").orderBy("key") + + checkAnswer( + df.select( + $"key", + avg("key").over(window.rangeBetween(Window.unboundedPreceding, 1)) + .as("avg_key1"), + avg("key").over(window.rangeBetween(Window.currentRow, Window.unboundedFollowing)) + .as("avg_key2")), + Row(3, 3.0d, 4.0d) :: Row(5, 4.0d, 5.0d) :: Row(2, 2.0d, 17.0d / 4.0d) :: + Row(4, 11.0d / 3.0d, 5.0d) :: Row(5, 17.0d / 4.0d, 11.0d / 2.0d) :: + Row(6, 
17.0d / 4.0d, 6.0d) :: Nil) + } + + // This is here to illustrate the fact that reverse order also reverses offsets. + test("reverse preceding/following range between with aggregation") { + val df = Seq(1, 2, 4, 3, 2, 1).toDF("value") + val window = Window.orderBy($"value".desc) + + checkAnswer( + df.select( + $"value", + sum($"value").over(window.rangeBetween(Window.unboundedPreceding, 1)), + sum($"value").over(window.rangeBetween(1, Window.unboundedFollowing))), + Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: Row(3, 11, 6) :: + Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + } + + test("sliding rows between with aggregation") { + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 3.0d / 2.0d) :: Row(2, 2.0d) :: + Row(2, 2.0d) :: Nil) + } + + test("reverse sliding rows between with aggregation") { + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key".desc).rowsBetween(-1, 2) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 1.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 4.0d / 3.0d) :: Row(2, 2.0d) :: + Row(2, 2.0d) :: Nil) + } + + test("sliding range between with aggregation") { + val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 7.0d / 4.0d) :: Row(3, 5.0d / 2.0d) :: + Row(2, 2.0d) :: Row(2, 2.0d) :: Nil) + } + + test("reverse sliding range between with aggregation") { + val df = Seq( + (1, "Thin", "Cell Phone", 6000), + (2, "Normal", "Tablet", 1500), + (3, "Mini", "Tablet", 5500), + (4, "Ultra thin", "Cell Phone", 5500), + (5, "Very thin", "Cell Phone", 6000), + (6, "Big", "Tablet", 2500), + (7, "Bendable", "Cell Phone", 3000), + (8, "Foldable", "Cell Phone", 3000), + (9, "Pro", "Tablet", 4500), + (10, "Pro2", "Tablet", 6500)). + toDF("id", "product", "category", "revenue") + val window = Window.partitionBy($"category").orderBy($"revenue".desc). 
+ rangeBetween(-2000L, 1000L) + + checkAnswer( + df.select( + $"id", + avg($"revenue").over(window).cast("int")), + Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: + Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: + Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: + Row(10, 6000) :: Nil) + } + + test("SPARK-24033: Analysis Failure of OffsetWindowFunction") { + val ds = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("n", "i") + val res = + Row(1, 1, null) :: Row (1, 2, 1) :: Row(1, 3, 2) :: Row(2, 1, null) :: Row(2, 2, 1) :: Nil + checkAnswer( + ds.withColumn("m", + lead("i", -1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + checkAnswer( + ds.withColumn("m", + lag("i", 1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 01c988ecc372..3ea398aad737 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import scala.collection.mutable + import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ @@ -55,56 +57,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } - test("Window.rowsBetween") { - val df = Seq(("one", 1), ("two", 2)).toDF("key", "value") - // Running (cumulative) sum - checkAnswer( - df.select('key, sum("value").over( - Window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), - Row("one", 1) :: Row("two", 3) :: Nil - ) - } - - test("lead") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - - checkAnswer( - df.select( - lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil) - } - - test("lag") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - - checkAnswer( - df.select( - lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil) - } - - test("lead with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a"))) - } - - test("lag with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"))) - } - test("rank functions in unspecific window") { val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") df.createOrReplaceTempView("window_table") @@ -136,197 +88,234 @@ class 
DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("requires window to be ordered")) } - test("aggregation and rows between") { - val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d))) - } - - test("aggregation and range between") { - val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d), - Row(2.0d), Row(2.0d))) - } - - test("row between should accept integer values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), - (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), - Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) - - val e = intercept[AnalysisException]( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) - assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) - } - - test("range between should accept int/long values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), - (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), - Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) + test("corr, covar_pop, stddev_pop functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") checkAnswer( df.select( $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), - Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) - ) - - def dt(date: String): Date = Date.valueOf(date) - - val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), - (dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2")) - .toDF("key", "value") - checkAnswer( - df2.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)))), - Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1), - Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) - ) - } - - test("range between should accept double values as boundary") { - val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), - (3.3D, "2"), (2.02D, "1"), (100.001D, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") + corr("value1", 
"value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + covar_pop("value1", "value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + + // As stddev_pop(expr) = sqrt(var_pop(expr)) + // the "stddev_pop" column can be calculated from the "var_pop" column. + // + // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2)) + // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns. + Seq( + Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0))) + } + + test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") checkAnswer( df.select( $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key") - .rangeBetween(currentRow, lit(2.5D)))), - Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) - ) - } - - test("range between should accept interval values as boundary") { - def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) - - val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), - (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") + covar_samp("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + variance("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) + ), + Seq( + Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("c", 0.0, 0.0, 
0.0, 0.0, 0.0 ), + Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ), + Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))) + } + + test("collect_list in ascending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") checkAnswer( df.select( $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key") - .rangeBetween(currentRow, - lit(CalendarInterval.fromString("interval 23 days 4 hours"))))), - Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), - Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) - ) - } - - test("aggregation and rows between with unbounded") { - val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") - df.createOrReplaceTempView("window_table") + sort_array( + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "2", "3")), + Row("b", Array("1", "2", "2", "3")), + Row("c", Array("1", "2", "2", "3")), + Row("d", Array("1", "2", "2", "3")), + Row("e", Array("1", "2", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_list in descending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") checkAnswer( df.select( $"key", - last("key").over( - Window.partitionBy($"value").orderBy($"key") - .rowsBetween(Window.currentRow, Window.unboundedFollowing)), - last("key").over( - Window.partitionBy($"value").orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.currentRow)), - last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), - Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), - Row(4, 4, 4, 4))) - } - - test("aggregation and range between with unbounded") { - val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") + sort_array( + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc) + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "2", "3")), + Row("b", Array("1", "2", "2", "3")), + Row("c", Array("1", "2", "2", "3")), + Row("d", Array("1", "2", "2", "3")), + Row("e", Array("1", "2", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_set in window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", "3"), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20")).toDF("key", "partition", "value") checkAnswer( df.select( $"key", - last("value").over( - 
Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)) - .equalTo("2") - .as("last_v"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) - .as("avg_key1"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) - .as("avg_key2"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) - .as("avg_key3") - ), - Seq(Row(3, null, 3.0d, 4.0d, 3.0d), - Row(5, false, 4.0d, 5.0d, 5.0d), - Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d), - Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d), - Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d), - Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d))) + sort_array( + collect_set("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "3")), + Row("b", Array("1", "2", "3")), + Row("c", Array("1", "2", "3")), + Row("d", Array("1", "2", "3")), + Row("e", Array("1", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")))) } - test("reverse sliding range frame") { + test("skewness and kurtosis functions in window") { val df = Seq( - (1, "Thin", "Cell Phone", 6000), - (2, "Normal", "Tablet", 1500), - (3, "Mini", "Tablet", 5500), - (4, "Ultra thin", "Cell Phone", 5500), - (5, "Very thin", "Cell Phone", 6000), - (6, "Big", "Tablet", 2500), - (7, "Bendable", "Cell Phone", 3000), - (8, "Foldable", "Cell Phone", 3000), - (9, "Pro", "Tablet", 4500), - (10, "Pro2", "Tablet", 6500)). - toDF("id", "product", "category", "revenue") - val window = Window. - partitionBy($"category"). - orderBy($"revenue".desc). - rangeBetween(-2000L, 1000L) + ("a", "p1", 1.0), + ("b", "p1", 1.0), + ("c", "p1", 2.0), + ("d", "p1", 2.0), + ("e", "p1", 3.0), + ("f", "p1", 3.0), + ("g", "p1", 3.0), + ("h", "p2", 1.0), + ("i", "p2", 2.0), + ("j", "p2", 5.0)).toDF("key", "partition", "value") checkAnswer( df.select( - $"id", - avg($"revenue").over(window).cast("int")), - Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: - Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: - Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: - Row(10, 6000) :: Nil) + $"key", + skewness("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + kurtosis("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + // results are checked by scipy.stats.skew() and scipy.stats.kurtosis() + Seq( + Row("a", -0.27238010581457267, -1.506920415224914), + Row("b", -0.27238010581457267, -1.506920415224914), + Row("c", -0.27238010581457267, -1.506920415224914), + Row("d", -0.27238010581457267, -1.506920415224914), + Row("e", -0.27238010581457267, -1.506920415224914), + Row("f", -0.27238010581457267, -1.506920415224914), + Row("g", -0.27238010581457267, -1.506920415224914), + Row("h", 0.5280049792181881, -1.5000000000000013), + Row("i", 0.5280049792181881, -1.5000000000000013), + Row("j", 0.5280049792181881, -1.5000000000000013))) + } + + test("aggregation function on invalid column") { + val df = Seq((1, "1")).toDF("key", "value") + val e = intercept[AnalysisException]( + df.select($"key", count("invalid").over())) + assert(e.message.contains("cannot resolve '`invalid`' given input columns: [key, value]")) } - // This is here to illustrate the fact that reverse order also reverses offsets. 
- test("reverse unbounded range frame") { - val df = Seq(1, 2, 4, 3, 2, 1). - map(Tuple1.apply). - toDF("value") - val window = Window.orderBy($"value".desc) + test("numerical aggregate functions on string column") { + val df = Seq((1, "a", "b")).toDF("key", "value1", "value2") checkAnswer( - df.select( - $"value", - sum($"value").over(window.rangeBetween(Long.MinValue, 1)), - sum($"value").over(window.rangeBetween(1, Long.MaxValue))), - Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: - Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + df.select($"key", + var_pop("value1").over(), + variance("value1").over(), + stddev_pop("value1").over(), + stddev("value1").over(), + sum("value1").over(), + mean("value1").over(), + avg("value1").over(), + corr("value1", "value2").over(), + covar_pop("value1", "value2").over(), + covar_samp("value1", "value2").over(), + skewness("value1").over(), + kurtosis("value1").over()), + Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null))) } test("statistical functions") { @@ -475,6 +464,40 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row("b", 2, null, null, null, null, null, null))) } + test("last/first on descending ordered window") { + val nullStr: String = null + val df = Seq( + ("a", 0, nullStr), + ("a", 1, "x"), + ("a", 2, "y"), + ("a", 3, "z"), + ("a", 4, "v"), + ("b", 1, "k"), + ("b", 2, "l"), + ("b", 3, nullStr)). + toDF("key", "order", "value") + val window = Window.partitionBy($"key").orderBy($"order".desc) + checkAnswer( + df.select( + $"key", + $"order", + first($"value").over(window), + first($"value", ignoreNulls = false).over(window), + first($"value", ignoreNulls = true).over(window), + last($"value").over(window), + last($"value", ignoreNulls = false).over(window), + last($"value", ignoreNulls = true).over(window)), + Seq( + Row("a", 0, "v", "v", "v", null, null, "x"), + Row("a", 1, "v", "v", "v", "x", "x", "x"), + Row("a", 2, "v", "v", "v", "y", "y", "y"), + Row("a", 3, "v", "v", "v", "z", "z", "z"), + Row("a", 4, "v", "v", "v", "v", "v", "v"), + Row("b", 1, null, null, "l", "k", "k", "k"), + Row("b", 2, null, null, "l", "l", "l", "l"), + Row("b", 3, null, null, null, null, null, null))) + } + test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { val src = Seq((0, 3, 5)).toDF("a", "b", "c") .withColumn("Data", struct("a", "b")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d535896723bd..e0f4d2ba685e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -958,12 +958,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ).toDS() val expected = - """+-------+ - || f| - |+-------+ - ||[foo,1]| - ||[bar,2]| - |+-------+ + """+--------+ + || f| + |+--------+ + ||[foo, 1]| + ||[bar, 2]| + |+--------+ |""".stripMargin checkShowString(ds, expected) @@ -1441,8 +1441,27 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getCause.isInstanceOf[NullPointerException]) } } + + test("SPARK-23025: Add support for null type in scala reflection") { + val data = Seq(("a", null)) + checkDataset(data.toDS(), data: _*) + } + + test("SPARK-23614: Union produces incorrect results when caching is used") { + val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache() + val group1 = 
cached.groupBy("x").agg(min(col("y")) as "value") + val group2 = cached.groupBy("x").agg(min(col("z")) as "value") + checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil) + } + + test("SPARK-23835: null primitive data type should throw NullPointerException") { + val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() + intercept[NullPointerException](ds.as[(Int, Int)].collect()) + } } +case class TestDataUnion(x: Int, y: Int, z: Int) + case class SingleData(id: Int) case class DoubleData(id: Int, val1: String) case class TripleData(id: Int, val1: String, val2: Long) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 6bbf38516cdf..237412aa692e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -23,6 +23,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval @@ -327,6 +328,13 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + checkAnswer(df.selectExpr("months_between(t, s, true)"), Seq(Row(0.5), Row(-0.5))) + Seq(true, false).foreach { roundOff => + checkAnswer(df.select(months_between(col("t"), col("d"), roundOff)), + Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.withColumn("r", lit(false)).selectExpr("months_between(t, s, r)"), + Seq(Row(0.5), Row(-0.5))) + } } test("function last_day") { @@ -689,4 +697,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-25 07:00:00")))) } + test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { + withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { + checkAnswer( + sql("SELECT from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')"), + Row(Timestamp.valueOf("2000-10-09 18:00:00"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala new file mode 100644 index 000000000000..06303099f531 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.io.FileNotFoundException + +import org.apache.hadoop.fs.Path +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkException +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + + +class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + + private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") + private val nameWithSpecialChars = "sp&cial%c hars" + + allFileBasedDataSources.foreach { format => + test(s"Writing empty datasets should not fail - $format") { + withTempPath { dir => + Seq("str").toDS().limit(0).write.format(format).save(dir.getCanonicalPath) + } + } + } + + // `TEXT` data source always has a single column whose name is `value`. + allFileBasedDataSources.filterNot(_ == "text").foreach { format => + test(s"SPARK-23072 Write and read back unicode column names - $format") { + withTempPath { path => + val dir = path.getCanonicalPath + + // scalastyle:off nonascii + val df = Seq("a").toDF("한글") + // scalastyle:on nonascii + + df.write.format(format).option("header", "true").save(dir) + val answerDf = spark.read.format(format).option("header", "true").load(dir) + + assert(df.schema.sameType(answerDf.schema)) + checkAnswer(df, answerDf) + } + } + } + + // Only ORC/Parquet support this. `CSV` and `JSON` returns an empty schema. + // `TEXT` data source always has a single column whose name is `value`. + Seq("orc", "parquet").foreach { format => + test(s"SPARK-15474 Write and read back non-empty schema with empty dataframe - $format") { + withTempPath { file => + val path = file.getCanonicalPath + val emptyDf = Seq((true, 1, "str")).toDF().limit(0) + emptyDf.write.format(format).save(path) + + val df = spark.read.format(format).load(path) + assert(df.schema.sameType(emptyDf.schema)) + checkAnswer(df, emptyDf) + } + } + } + + Seq("orc", "parquet").foreach { format => + test(s"SPARK-23271 empty RDD when saved should write a metadata only file - $format") { + withTempPath { outputPath => + val df = spark.emptyDataFrame.select(lit(1).as("i")) + df.write.format(format).save(outputPath.toString) + val partFiles = outputPath.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + assert(partFiles.length === 1) + + // Now read the file. 
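+ // The read-back DataFrame should be empty, and its schema should match the written
+ // one with every field marked nullable.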
+ val df1 = spark.read.format(format).load(outputPath.toString) + checkAnswer(df1, Seq.empty[Row]) + assert(df1.schema.equals(df.schema.asNullable)) + } + } + } + + allFileBasedDataSources.foreach { format => + test(s"SPARK-23372 error while writing empty schema files using $format") { + withTempPath { outputPath => + val errMsg = intercept[AnalysisException] { + spark.emptyDataFrame.write.format(format).save(outputPath.toString) + } + assert(errMsg.getMessage.contains( + "Datasource does not support writing empty or nested empty schemas")) + } + + // Nested empty schema + withTempPath { outputPath => + val schema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", StructType(Nil)), + StructField("c", IntegerType) + )) + val df = spark.createDataFrame(sparkContext.emptyRDD[Row], schema) + val errMsg = intercept[AnalysisException] { + df.write.format(format).save(outputPath.toString) + } + assert(errMsg.getMessage.contains( + "Datasource does not support writing empty or nested empty schemas")) + } + } + } + + allFileBasedDataSources.foreach { format => + test(s"SPARK-22146 read files containing special characters using $format") { + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + val fileContent = spark.read.format(format).load(tmpFile) + checkAnswer(fileContent, Seq(Row("a"), Row("b"))) + } + } + } + + // Separate test case for formats that support multiLine as an option. + Seq("json", "csv").foreach { format => + test("SPARK-23148 read files containing special characters " + + s"using $format with multiline enabled") { + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + val reader = spark.read.format(format).option("multiLine", true) + val fileContent = reader.load(tmpFile) + checkAnswer(fileContent, Seq(Row("a"), Row("b"))) + } + } + } + + allFileBasedDataSources.foreach { format => + testQuietly(s"Enabling/disabling ignoreMissingFiles using $format") { + def testIgnoreMissingFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + + Seq("0").toDF("a").write.format(format).save(new Path(basePath, "first").toString) + Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) + + val thirdPath = new Path(basePath, "third") + val fs = thirdPath.getFileSystem(spark.sessionState.newHadoopConf()) + Seq("2").toDF("a").write.format(format).save(thirdPath.toString) + val files = fs.listStatus(thirdPath).filter(_.isFile).map(_.getPath) + + val df = spark.read.format(format).load( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + + // Make sure all data files are deleted and can't be opened. 
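+ // With ignoreMissingFiles enabled, the scan below should skip the deleted files and
+ // still return the rows written to the remaining directories; with it disabled, the
+ // same scan is expected to fail because the files no longer exist.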
+ files.foreach(f => fs.delete(f, false)) + assert(fs.delete(thirdPath, true)) + for (f <- files) { + intercept[FileNotFoundException](fs.open(f)) + } + + checkAnswer(df, Seq(Row("0"), Row("1"))) + } + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "true") { + testIgnoreMissingFiles() + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreMissingFiles() + } + assert(exception.getMessage().contains("does not exist")) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala new file mode 100644 index 000000000000..c6dd7dadc9d9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure read performance with Filter pushdown. 
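+ * It writes a synthetic table to both Parquet and native ORC and compares scan times
+ * for a range of selective and non-selective predicates with filter pushdown enabled
+ * and disabled.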
+ */ +object FilterPushdownBenchmark { + val conf = new SparkConf() + conf.set("orc.compression", "snappy") + conf.set("spark.sql.parquet.compression.codec", "snappy") + + private val spark = SparkSession.builder() + .master("local[1]") + .appName("FilterPushdownBenchmark") + .config(conf) + .getOrCreate() + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private def prepareTable(dir: File, numRows: Int, width: Int): Unit = { + import spark.implicits._ + val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") + val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) + .withColumn("id", monotonically_increasing_id()) + + val dirORC = dir.getCanonicalPath + "/orc" + val dirParquet = dir.getCanonicalPath + "/parquet" + + df.write.mode("overwrite").orc(dirORC) + df.write.mode("overwrite").parquet(dirParquet) + + spark.read.orc(dirORC).createOrReplaceTempView("orcTable") + spark.read.parquet(dirParquet).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5) + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() + } + } + } + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() + } + } + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Select 0 row (id IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7882 / 7957 2.0 501.1 1.0X + Parquet Vectorized (Pushdown) 55 / 60 285.2 3.5 142.9X + Native ORC Vectorized 5592 / 5627 2.8 355.5 1.4X + Native ORC Vectorized (Pushdown) 66 / 70 237.2 4.2 118.9X + + Select 0 row (7864320 < id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7884 / 7909 2.0 501.2 1.0X + Parquet Vectorized (Pushdown) 739 / 752 21.3 47.0 10.7X + Native ORC Vectorized 5614 / 5646 2.8 356.9 1.4X + Native ORC Vectorized (Pushdown) 81 / 83 195.2 5.1 97.8X + + Select 1 row (id = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + 
----------------------------------------------------------------------------------------------- + Parquet Vectorized 7905 / 8027 2.0 502.6 1.0X + Parquet Vectorized (Pushdown) 740 / 766 21.2 47.1 10.7X + Native ORC Vectorized 5684 / 5738 2.8 361.4 1.4X + Native ORC Vectorized (Pushdown) 78 / 81 202.4 4.9 101.7X + + Select 1 row (id <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7928 / 7993 2.0 504.1 1.0X + Parquet Vectorized (Pushdown) 747 / 772 21.0 47.5 10.6X + Native ORC Vectorized 5728 / 5753 2.7 364.2 1.4X + Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 104.8X + + Select 1 row (7864320 <= id <= 7864320):Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7939 / 8021 2.0 504.8 1.0X + Parquet Vectorized (Pushdown) 746 / 770 21.1 47.4 10.6X + Native ORC Vectorized 5690 / 5734 2.8 361.7 1.4X + Native ORC Vectorized (Pushdown) 76 / 79 206.7 4.8 104.3X + + Select 1 row (7864319 < id < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7972 / 8019 2.0 506.9 1.0X + Parquet Vectorized (Pushdown) 742 / 764 21.2 47.2 10.7X + Native ORC Vectorized 5704 / 5743 2.8 362.6 1.4X + Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 105.4X + + Select 10% rows (id < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 8733 / 8808 1.8 555.2 1.0X + Parquet Vectorized (Pushdown) 2213 / 2267 7.1 140.7 3.9X + Native ORC Vectorized 6420 / 6463 2.4 408.2 1.4X + Native ORC Vectorized (Pushdown) 1313 / 1331 12.0 83.5 6.7X + + Select 50% rows (id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 11518 / 11591 1.4 732.3 1.0X + Parquet Vectorized (Pushdown) 7962 / 7991 2.0 506.2 1.4X + Native ORC Vectorized 8927 / 8985 1.8 567.6 1.3X + Native ORC Vectorized (Pushdown) 6102 / 6160 2.6 387.9 1.9X + + Select 90% rows (id < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14255 / 14389 1.1 906.3 1.0X + Parquet Vectorized (Pushdown) 13564 / 13594 1.2 862.4 1.1X + Native ORC Vectorized 11442 / 11608 1.4 727.5 1.2X + Native ORC Vectorized (Pushdown) 10991 / 11029 1.4 698.8 1.3X + + Select all rows (id IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14917 / 14938 1.1 948.4 1.0X + Parquet Vectorized (Pushdown) 14910 / 14964 1.1 948.0 1.0X + Native ORC Vectorized 11986 / 12069 1.3 762.0 1.2X + Native ORC Vectorized (Pushdown) 12037 / 12123 1.3 765.3 1.2X + + Select all rows (id > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14951 / 14976 1.1 950.6 1.0X + Parquet Vectorized (Pushdown) 14934 / 15016 1.1 949.5 1.0X + Native ORC Vectorized 12000 / 12156 1.3 763.0 1.2X + Native ORC Vectorized (Pushdown) 12079 / 12113 1.3 
767.9 1.2X + + Select all rows (id != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14930 / 14972 1.1 949.3 1.0X + Parquet Vectorized (Pushdown) 15015 / 15047 1.0 954.6 1.0X + Native ORC Vectorized 12090 / 12259 1.3 768.7 1.2X + Native ORC Vectorized (Pushdown) 12021 / 12096 1.3 764.2 1.2X + */ + benchmark.run() + } + + def main(args: Array[String]): Unit = { + val numRows = 1024 * 1024 * 15 + val width = 5 + val mid = numRows / 2 + + withTempPath { dir => + withTempTable("orcTable", "parquetTable") { + prepareTable(dir, numRows, width) + + Seq("id IS NULL", s"$mid < id AND id < $mid").foreach { whereExpr => + val title = s"Select 0 row ($whereExpr)".replace("id AND id", "id") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"id = $mid", + s"id <=> $mid", + s"$mid <= id AND id <= $mid", + s"${mid - 1} < id AND id < ${mid + 1}" + ).foreach { whereExpr => + val title = s"Select 1 row ($whereExpr)".replace("id AND id", "id") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(id)") + + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% rows (id < ${numRows * percent / 100})", + s"id < ${numRows * percent / 100}", + selectExpr + ) + } + + Seq("id IS NOT NULL", "id > -1", "id != -1").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 771e1186e63a..8fa747465cb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -239,7 +239,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(2, 2, 1, null) :: Row(2, 2, 2, 2) :: Nil) } - assert(e.getMessage.contains("Detected cartesian product for INNER join " + + assert(e.getMessage.contains("Detected implicit cartesian product for INNER join " + "between logical plans")) } } @@ -611,7 +611,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val e = intercept[Exception] { checkAnswer(sql(query), Nil); } - assert(e.getMessage.contains("Detected cartesian product")) + assert(e.getMessage.contains("Detected implicit cartesian product")) } cartesianQueries.foreach(checkCartesianDetection) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e077285ade5..640affc10ee5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import java.io.File -import java.math.MathContext import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean @@ -28,8 +27,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} -import 
org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -1519,24 +1516,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("decimal precision with multiply/division") { - checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) - checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) - checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) - checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), - Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), - Row(null)) - - checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) - checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) - checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) - checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) - checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) - } - test("SPARK-10215 Div of Decimal returns null") { val d = Decimal(1.12321).toBigDecimal val df = Seq((d, 1)).toDF("a", "b") @@ -1585,36 +1564,38 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("specifying database name for a temporary view is not allowed") { withTempPath { dir => - val path = dir.toURI.toString - val df = - sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") - df - .write - .format("parquet") - .save(path) - - // We don't support creating a temporary table while specifying a database - intercept[AnalysisException] { + withTempView("db.t") { + val path = dir.toURI.toString + val df = + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + // We don't support creating a temporary table while specifying a database + intercept[AnalysisException] { + spark.sql( + s""" + |CREATE TEMPORARY VIEW db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + + // If you use backticks to quote the name then it's OK. spark.sql( s""" - |CREATE TEMPORARY VIEW db.t + |CREATE TEMPORARY VIEW `db.t` |USING parquet |OPTIONS ( | path '$path' |) """.stripMargin) - }.getMessage - - // If you use backticks to quote the name then it's OK. 
- spark.sql( - s""" - |CREATE TEMPORARY VIEW `db.t` - |USING parquet - |OPTIONS ( - | path '$path' - |) - """.stripMargin) - checkAnswer(spark.table("`db.t`"), df) + checkAnswer(spark.table("`db.t`"), df) + } } } @@ -1636,6 +1617,46 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23281: verify the correctness of sort direction on composite order by clause") { + withTempView("src") { + Seq[(Integer, Integer)]( + (1, 1), + (1, 3), + (2, 3), + (3, 3), + (4, null), + (5, null) + ).toDF("key", "value").createOrReplaceTempView("src") + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value desc, key + """.stripMargin), + Seq(Row(3, 1), Row(3, 2), Row(3, 3), Row(null, 4), Row(null, 5))) + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value desc, key desc + """.stripMargin), + Seq(Row(3, 3), Row(3, 2), Row(3, 1), Row(null, 5), Row(null, 4))) + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value asc, key desc + """.stripMargin), + Seq(Row(null, 5), Row(null, 4), Row(3, 3), Row(3, 2), Row(3, 1))) + } + } + test("run sql directly on files") { val df = spark.range(100).toDF() withTempPath(f => { @@ -1916,12 +1937,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { var e = intercept[AnalysisException] { sql("SELECT a.* FROM temp_table_no_cols a") }.getMessage - assert(e.contains("cannot resolve 'a.*' give input columns ''")) + assert(e.contains("cannot resolve 'a.*' given input columns ''")) e = intercept[AnalysisException] { dfNoCols.select($"b.*") }.getMessage - assert(e.contains("cannot resolve 'b.*' give input columns ''")) + assert(e.contains("cannot resolve 'b.*' given input columns ''")) } } @@ -2129,7 +2150,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("data source table created in InMemoryCatalog should be able to read/write") { withTable("tbl") { - sql("CREATE TABLE tbl(i INT, j STRING) USING parquet") + val provider = spark.sessionState.conf.defaultDataSourceName + sql(s"CREATE TABLE tbl(i INT, j STRING) USING $provider") checkAnswer(sql("SELECT i, j FROM tbl"), Nil) Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") @@ -2453,9 +2475,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-16975: Column-partition path starting '_' should be handled correctly") { withTempDir { dir => - val parquetDir = new File(dir, "parquet").getCanonicalPath - spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(parquetDir) - spark.read.parquet(parquetDir) + val dataDir = new File(dir, "data").getCanonicalPath + spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(dataDir) + spark.read.load(dataDir) } } @@ -2719,6 +2741,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23079: constraints should be inferred correctly with aliases") { + withTable("t") { + spark.range(5).write.saveAsTable("t") + val t = spark.read.table("t") + val left = t.withColumn("xid", $"id" + lit(1)).as("x") + val right = t.withColumnRenamed("id", "xid").as("y") + val df = left.join(right, "xid").filter("id = 3").toDF() + checkAnswer(df, Row(4, 3)) + } + } + test("SRARK-22266: the same aggregate function was calculated multiple times") { val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a" val df = sql(query) @@ -2759,20 +2792,4 @@ class 
SQLQuerySuite extends QueryTest with SharedSQLContext { } } } - - // Only New OrcFileFormat supports this - Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, - "parquet").foreach { format => - test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { - withTempPath { file => - val path = file.getCanonicalPath - val emptyDf = Seq((true, 1, "str")).toDF.limit(0) - emptyDf.write.format(format).save(path) - - val df = spark.read.format(format).load(path) - assert(df.schema.sameType(emptyDf.schema)) - checkAnswer(df, emptyDf) - } - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index e3901af4b998..beac9699585d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -230,7 +230,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/") .replaceAll("Created By.*", s"Created By $notIncludedMsg") .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") - .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")) + .replaceAll("Last Access.*", s"Last Access $notIncludedMsg") + .replaceAll("\\*\\(\\d+\\) ", "*")) // remove the WholeStageCodegen codegenStageIds // If the output is not pre-sorted, sort it. if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) @@ -291,7 +292,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index c01666770720..7d1366092d1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll -import org.scalatest.BeforeAndAfterEach import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite @@ -28,8 +26,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener -class SessionStateSuite extends SparkFunSuite - with BeforeAndAfterEach with BeforeAndAfterAll { +class SessionStateSuite extends SparkFunSuite { /** * A shared SparkSession for all tests in this suite. 
Make sure you reset any changes to this @@ -39,6 +36,7 @@ class SessionStateSuite protected var activeSession: SparkSession = _ override def beforeAll(): Unit = { + super.beforeAll() activeSession = SparkSession.builder().master("local").getOrCreate() } @@ -46,6 +44,8 @@ if (activeSession != null) { activeSession.stop() activeSession = null + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index c0301f2ce2d6..44bf8624a6bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -50,6 +50,24 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach { assert(SparkSession.builder().getOrCreate() == session) } + test("sets default and active session") { + assert(SparkSession.getDefaultSession == None) + assert(SparkSession.getActiveSession == None) + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.getDefaultSession == Some(session)) + assert(SparkSession.getActiveSession == Some(session)) + } + + test("get active or default session") { + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.active == session) + SparkSession.clearActiveSession() + assert(SparkSession.active == session) + SparkSession.clearDefaultSession() + intercept[IllegalStateException](SparkSession.active) + session.stop() + } + test("config options are propagated to existing SparkSession") { val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() assert(session1.conf.get("spark-config1") == "a") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b11e79853205..60fa951e2317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -17,14 +17,18 @@ package org.apache.spark.sql +import java.io.File + import scala.collection.mutable import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** @@ -46,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") - assert(sizes.head === BigInt(96), - s"expected exact size 96 for table 'test', got: ${sizes.head}") + assert(sizes.head === BigInt(128), + s"expected exact size 128 for table 'test', got: ${sizes.head}") } } @@ -95,7 +99,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetchedStats2.get.sizeInBytes == 0) val expectedColStat = - "key" -> ColumnStat(0, None, None, 0, IntegerType.defaultSize, IntegerType.defaultSize) + "key" -> CatalogColumnStat(Some(0), None, None, Some(0), + Some(IntegerType.defaultSize), Some(IntegerType.defaultSize)) // There won't be histogram for empty column.
Seq("true", "false").foreach { histogramEnabled => @@ -156,7 +161,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared Seq(stats, statsWithHgms).foreach { s => s.zip(df.schema).foreach { case ((k, v), field) => withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) + val roundtrip = CatalogColumnStat.fromMap("table_is_foo", field.name, v.toMap(k)) assert(roundtrip == Some(v)) } } @@ -187,7 +192,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared }.mkString(", ")) val expectedColStats = dataTypes.map { case (tpe, idx) => - (s"col$idx", ColumnStat(0, None, None, 1, tpe.defaultSize.toLong, tpe.defaultSize.toLong)) + (s"col$idx", CatalogColumnStat(Some(0), None, None, Some(1), + Some(tpe.defaultSize.toLong), Some(tpe.defaultSize.toLong))) } // There won't be histograms for null columns. @@ -239,6 +245,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("change stats after set location command") { val table = "change_stats_set_location_table" + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier(table))) Seq(false, true).foreach { autoUpdate => withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { @@ -266,6 +273,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetched3.get.sizeInBytes == fetched1.get.sizeInBytes) } else { checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + // SPARK-19724: clean up the previous table location. + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) } } } @@ -372,4 +382,32 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } } + + test("Simple queries must be working, if CBO is turned on") { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { + withTable("TBL1", "TBL") { + import org.apache.spark.sql.functions._ + val df = spark.range(1000L).select('id, + 'id * 2 as "FLD1", + 'id * 12 as "FLD2", + lit("aaa") + 'id as "fld3") + df.write + .mode(SaveMode.Overwrite) + .bucketBy(10, "id", "FLD1", "FLD2") + .sortBy("id", "FLD1", "FLD2") + .saveAsTable("TBL") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS ") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS FOR COLUMNS ID, FLD1, FLD2, FLD3") + val df2 = spark.sql( + """ + |SELECT t1.id, t1.fld1, t1.fld2, t1.fld3 + |FROM tbl t1 + |JOIN tbl t2 on t1.id=t2.id + |WHERE t1.fld3 IN (-123.23,321.23) + """.stripMargin) + df2.createTempView("TBL2") + sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ").queryExecution.executedPlan + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index 65ccc1915882..bf4abb6e625c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -24,8 +24,8 @@ import scala.collection.mutable import scala.util.Random import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, LogicalPlan} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTable, 
HiveTableRelation} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, HistogramSerializer, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -67,18 +67,21 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils /** A mapping from column to the stats collected. */ protected val stats = mutable.LinkedHashMap( - "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), - "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), - "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), - "cstring" -> ColumnStat(2, None, None, 1, 3, 3), - "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(d1Internal), Some(d2Internal), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(t1Internal), Some(t2Internal), 1, 8, 8) + "cbool" -> CatalogColumnStat(Some(2), Some("false"), Some("true"), Some(1), Some(1), Some(1)), + "cbyte" -> CatalogColumnStat(Some(2), Some("1"), Some("2"), Some(1), Some(1), Some(1)), + "cshort" -> CatalogColumnStat(Some(2), Some("1"), Some("3"), Some(1), Some(2), Some(2)), + "cint" -> CatalogColumnStat(Some(2), Some("1"), Some("4"), Some(1), Some(4), Some(4)), + "clong" -> CatalogColumnStat(Some(2), Some("1"), Some("5"), Some(1), Some(8), Some(8)), + "cdouble" -> CatalogColumnStat(Some(2), Some("1.0"), Some("6.0"), Some(1), Some(8), Some(8)), + "cfloat" -> CatalogColumnStat(Some(2), Some("1.0"), Some("7.0"), Some(1), Some(4), Some(4)), + "cdecimal" -> CatalogColumnStat(Some(2), Some(dec1.toString), Some(dec2.toString), Some(1), + Some(16), Some(16)), + "cstring" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cbinary" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cdate" -> CatalogColumnStat(Some(2), Some(d1.toString), Some(d2.toString), Some(1), Some(4), + Some(4)), + "ctimestamp" -> CatalogColumnStat(Some(2), Some(t1.toString), Some(t2.toString), Some(1), + Some(8), Some(8)) ) /** @@ -110,6 +113,110 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats } + val expectedSerializedColStats = Map( + "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", + "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", + "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", + "spark.sql.statistics.colStats.cbinary.version" -> "1", + "spark.sql.statistics.colStats.cbool.avgLen" -> "1", + "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbool.max" -> "true", + "spark.sql.statistics.colStats.cbool.maxLen" -> "1", + "spark.sql.statistics.colStats.cbool.min" -> "false", + "spark.sql.statistics.colStats.cbool.nullCount" -> "1", + "spark.sql.statistics.colStats.cbool.version" -> "1", + "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", + "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbyte.max" -> "2", + "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", + 
"spark.sql.statistics.colStats.cbyte.min" -> "1", + "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", + "spark.sql.statistics.colStats.cbyte.version" -> "1", + "spark.sql.statistics.colStats.cdate.avgLen" -> "4", + "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", + "spark.sql.statistics.colStats.cdate.maxLen" -> "4", + "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", + "spark.sql.statistics.colStats.cdate.nullCount" -> "1", + "spark.sql.statistics.colStats.cdate.version" -> "1", + "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", + "spark.sql.statistics.colStats.cdecimal.version" -> "1", + "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", + "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdouble.max" -> "6.0", + "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", + "spark.sql.statistics.colStats.cdouble.min" -> "1.0", + "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", + "spark.sql.statistics.colStats.cdouble.version" -> "1", + "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", + "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", + "spark.sql.statistics.colStats.cfloat.max" -> "7.0", + "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", + "spark.sql.statistics.colStats.cfloat.min" -> "1.0", + "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", + "spark.sql.statistics.colStats.cfloat.version" -> "1", + "spark.sql.statistics.colStats.cint.avgLen" -> "4", + "spark.sql.statistics.colStats.cint.distinctCount" -> "2", + "spark.sql.statistics.colStats.cint.max" -> "4", + "spark.sql.statistics.colStats.cint.maxLen" -> "4", + "spark.sql.statistics.colStats.cint.min" -> "1", + "spark.sql.statistics.colStats.cint.nullCount" -> "1", + "spark.sql.statistics.colStats.cint.version" -> "1", + "spark.sql.statistics.colStats.clong.avgLen" -> "8", + "spark.sql.statistics.colStats.clong.distinctCount" -> "2", + "spark.sql.statistics.colStats.clong.max" -> "5", + "spark.sql.statistics.colStats.clong.maxLen" -> "8", + "spark.sql.statistics.colStats.clong.min" -> "1", + "spark.sql.statistics.colStats.clong.nullCount" -> "1", + "spark.sql.statistics.colStats.clong.version" -> "1", + "spark.sql.statistics.colStats.cshort.avgLen" -> "2", + "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", + "spark.sql.statistics.colStats.cshort.max" -> "3", + "spark.sql.statistics.colStats.cshort.maxLen" -> "2", + "spark.sql.statistics.colStats.cshort.min" -> "1", + "spark.sql.statistics.colStats.cshort.nullCount" -> "1", + "spark.sql.statistics.colStats.cshort.version" -> "1", + "spark.sql.statistics.colStats.cstring.avgLen" -> "3", + "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", + "spark.sql.statistics.colStats.cstring.maxLen" -> "3", + "spark.sql.statistics.colStats.cstring.nullCount" -> "1", + "spark.sql.statistics.colStats.cstring.version" -> "1", + "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", + "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", + "spark.sql.statistics.colStats.ctimestamp.maxLen" 
-> "8", + "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", + "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", + "spark.sql.statistics.colStats.ctimestamp.version" -> "1" + ) + + val expectedSerializedHistograms = Map( + "spark.sql.statistics.colStats.cbyte.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), + "spark.sql.statistics.colStats.cshort.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), + "spark.sql.statistics.colStats.cint.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), + "spark.sql.statistics.colStats.clong.histogram" -> + HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), + "spark.sql.statistics.colStats.cdouble.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), + "spark.sql.statistics.colStats.cfloat.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), + "spark.sql.statistics.colStats.cdecimal.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), + "spark.sql.statistics.colStats.cdate.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), + "spark.sql.statistics.colStats.ctimestamp.histogram" -> + HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) + ) + private val randomName = new Random(31) def getCatalogTable(tableName: String): CatalogTable = { @@ -151,7 +258,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils */ def checkColStats( df: DataFrame, - colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { val tableName = "column_stats_test_" + randomName.nextInt(1000) withTable(tableName) { df.write.saveAsTable(tableName) @@ -161,14 +268,24 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats.keys.mkString(", ")) // Validate statistics - val table = getCatalogTable(tableName) - assert(table.stats.isDefined) - assert(table.stats.get.colStats.size == colStats.size) - - colStats.foreach { case (k, v) => - withClue(s"column $k") { - assert(table.stats.get.colStats(k) == v) - } + validateColStats(tableName, colStats) + } + } + + /** + * Validate if the given catalog table has the provided statistics. 
+ */ + def validateColStats( + tableName: String, + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { + + val table = getCatalogTable(tableName) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == colStats.size) + + colStats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) } } } @@ -215,12 +332,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils case catalogRel: HiveTableRelation => (catalogRel, catalogRel.tableMeta) case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) }.head - val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) + val emptyColStat = ColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) + val emptyCatalogColStat = CatalogColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) // Check catalog statistics assert(catalogTable.stats.isDefined) assert(catalogTable.stats.get.sizeInBytes == 0) assert(catalogTable.stats.get.rowCount == Some(0)) - assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) + assert(catalogTable.stats.get.colStats == Map("c1" -> emptyCatalogColStat)) // Check relation statistics withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 8673dc14f759..acef62d81ee1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -950,4 +950,24 @@ class SubquerySuite extends QueryTest with SharedSQLContext { assert(join.duplicateResolved) assert(optimizedPlan.resolved) } + + test("SPARK-23316: AnalysisException after max iteration reached for IN query") { + // before the fix this would throw AnalysisException + spark.range(10).where("(id,id) in (select id, null from range(3))").count + } + + test("SPARK-24085 scalar subquery in partitioning expression") { + withTable("parquet_part") { + Seq("1" -> "a", "2" -> "a", "3" -> "b", "4" -> "b") + .toDF("id_value", "id_type") + .write + .mode(SaveMode.Overwrite) + .partitionBy("id_type") + .format("parquet") + .saveAsTable("parquet_part") + checkAnswer( + sql("SELECT * FROM parquet_part WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index 1a584187a06e..bc95b4696190 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -62,7 +62,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`c_first_sales_date_sk` INT, `c_salutation` STRING, `c_first_name` STRING, |`c_last_name` STRING, `c_preferred_cust_flag` STRING, `c_birth_day` INT, |`c_birth_month` INT, `c_birth_year` INT, `c_birth_country` STRING, `c_login` STRING, - |`c_email_address` STRING, `c_last_review_date` STRING) + |`c_email_address` STRING, `c_last_review_date` INT) |USING parquet """.stripMargin) @@ -88,7 +88,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ |CREATE TABLE `date_dim` ( - |`d_date_sk` INT, `d_date_id` STRING, `d_date` STRING, + |`d_date_sk` INT, `d_date_id` STRING, `d_date` DATE, |`d_month_seq` INT, `d_week_seq` INT, `d_quarter_seq` INT, `d_year` INT, `d_dow` INT, |`d_moy` INT, `d_dom` INT, `d_qoy` INT, `d_fy_year` INT, `d_fy_quarter_seq` INT, 
|`d_fy_week_seq` INT, `d_day_name` STRING, `d_quarter_name` STRING, `d_holiday` STRING, @@ -115,8 +115,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ - |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` STRING, - |`i_rec_end_date` STRING, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), + |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` DATE, + |`i_rec_end_date` DATE, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), |`i_wholesale_cost` DECIMAL(7,2), `i_brand_id` INT, `i_brand` STRING, `i_class_id` INT, |`i_class` STRING, `i_category_id` INT, `i_category` STRING, `i_manufact_id` INT, |`i_manufact` STRING, `i_size` STRING, `i_formulation` STRING, `i_color` STRING, @@ -139,8 +139,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ |CREATE TABLE `store` ( - |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` STRING, - |`s_rec_end_date` STRING, `s_closed_date_sk` INT, `s_store_name` STRING, + |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` DATE, + |`s_rec_end_date` DATE, `s_closed_date_sk` INT, `s_store_name` STRING, |`s_number_employees` INT, `s_floor_space` INT, `s_hours` STRING, `s_manager` STRING, |`s_market_id` INT, `s_geography_class` STRING, `s_market_desc` STRING, |`s_market_manager` STRING, `s_division_id` INT, `s_division_name` STRING, @@ -157,7 +157,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`sr_returned_date_sk` BIGINT, `sr_return_time_sk` BIGINT, `sr_item_sk` BIGINT, |`sr_customer_sk` BIGINT, `sr_cdemo_sk` BIGINT, `sr_hdemo_sk` BIGINT, `sr_addr_sk` BIGINT, |`sr_store_sk` BIGINT, `sr_reason_sk` BIGINT, `sr_ticket_number` BIGINT, - |`sr_return_quantity` BIGINT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), + |`sr_return_quantity` INT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), |`sr_return_amt_inc_tax` DECIMAL(7,2), `sr_fee` DECIMAL(7,2), |`sr_return_ship_cost` DECIMAL(7,2), `sr_refunded_cash` DECIMAL(7,2), |`sr_reversed_charge` DECIMAL(7,2), `sr_store_credit` DECIMAL(7,2), @@ -225,7 +225,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`wr_refunded_hdemo_sk` BIGINT, `wr_refunded_addr_sk` BIGINT, |`wr_returning_customer_sk` BIGINT, `wr_returning_cdemo_sk` BIGINT, |`wr_returning_hdemo_sk` BIGINT, `wr_returning_addr_sk` BIGINT, `wr_web_page_sk` BIGINT, - |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` BIGINT, + |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` INT, |`wr_return_amt` DECIMAL(7,2), `wr_return_tax` DECIMAL(7,2), |`wr_return_amt_inc_tax` DECIMAL(7,2), `wr_fee` DECIMAL(7,2), |`wr_return_ship_cost` DECIMAL(7,2), `wr_refunded_cash` DECIMAL(7,2), @@ -244,7 +244,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`web_company_id` INT, `web_company_name` STRING, `web_street_number` STRING, |`web_street_name` STRING, `web_street_type` STRING, `web_suite_number` STRING, |`web_city` STRING, `web_county` STRING, `web_state` STRING, `web_zip` STRING, - |`web_country` STRING, `web_gmt_offset` STRING, `web_tax_percentage` DECIMAL(5,2)) + |`web_country` STRING, `web_gmt_offset` DECIMAL(5,2), `web_tax_percentage` DECIMAL(5,2)) |USING parquet """.stripMargin) @@ -315,6 +315,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { """.stripMargin) } + // The TPCDS queries below are based on v1.4 val tpcdsQueries = Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -339,6 +340,25 @@ 
class TPCDSQuerySuite extends BenchmarkQueryTest { } } + // This list only includes TPCDS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7_0 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + + tpcdsQueriesV2_7_0.foreach { name => + val queryString = resourceToString(s"tpcds-v2.7.0/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(s"$name-v2.7") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) + } + } + } + // These queries are from https://github.com/cloudera/impala-tpcds-kit/tree/master/queries val modifiedTPCDSQueries = Seq( "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala new file mode 100644 index 000000000000..d2a6358ee822 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.QueryExecutionListener + + +class TestQueryExecutionListener extends QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + OnSuccessCall.isOnSuccessCalled.set(true) + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { } +} + +/** + * This has a variable to check if `onSuccess` is actually called or not. Currently, this is for + * the test case in PySpark. See SPARK-23942. 
+ */ +object OnSuccessCall { + val isOnSuccessCalled = new AtomicBoolean(false) + + def isCalled(): Boolean = isOnSuccessCalled.get() + + def clear(): Unit = isOnSuccessCalled.set(false) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7f1c009ca6e7..21afdc7e2a33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.types.DataTypes +import org.apache.spark.sql.types.{DataTypes, DoubleType} private case class FunctionResult(f1: String, f2: String) @@ -79,7 +80,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } - assert(e.getMessage.contains("Invalid number of arguments for function substr")) + assert(e.getMessage.contains("Invalid number of arguments for function substr. Expected:")) } test("error reporting for incorrect number of arguments - udf") { @@ -88,7 +89,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { spark.udf.register("foo", (_: String).length) df.selectExpr("foo(2, 3, 4)") } - assert(e.getMessage.contains("Invalid number of arguments for function foo")) + assert(e.getMessage.contains("Invalid number of arguments for function foo. Expected:")) } test("error reporting for undefined functions") { @@ -128,6 +129,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df2 = testData.select(bar()) assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) assert(df2.head().getDouble(0) >= 0.0) + + val javaUdf = udf(new UDF0[Double] { + override def call(): Double = Math.random() + }, DoubleType).asNondeterministic() + val df3 = testData.select(javaUdf()) + assert(df3.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df3.head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { @@ -136,73 +144,81 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a WHERE") { - spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + withTempView("integerData") { + spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) - val df = sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - df.createOrReplaceTempView("integerData") + val df = sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.createOrReplaceTempView("integerData") - val result = - sql("SELECT * FROM integerData WHERE oneArgFilter(key)") - assert(result.count() === 20) + val result = + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + assert(result.count() === 20) + } } test("UDF in a HAVING") { - spark.udf.register("havingFilter", (n: Long) => { n > 5 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT g, SUM(v) as s - | FROM groupData - | GROUP BY g - | HAVING havingFilter(s) - """.stripMargin) - - assert(result.count() === 2) + 
withTempView("groupData") { + spark.udf.register("havingFilter", (n: Long) => { n > 5 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT g, SUM(v) as s + | FROM groupData + | GROUP BY g + | HAVING havingFilter(s) + """.stripMargin) + + assert(result.count() === 2) + } } test("UDF in a GROUP BY") { - spark.udf.register("groupFunction", (n: Int) => { n > 10 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT SUM(v) - | FROM groupData - | GROUP BY groupFunction(v) - """.stripMargin) - assert(result.count() === 2) + withTempView("groupData") { + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT SUM(v) + | FROM groupData + | GROUP BY groupFunction(v) + """.stripMargin) + assert(result.count() === 2) + } } test("UDFs everywhere") { - spark.udf.register("groupFunction", (n: Int) => { n > 10 }) - spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) - spark.udf.register("whereFilter", (n: Int) => { n < 150 }) - spark.udf.register("timesHundred", (n: Long) => { n * 100 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT timesHundred(SUM(v)) as v100 - | FROM groupData - | WHERE whereFilter(v) - | GROUP BY groupFunction(v) - | HAVING havingFilter(v100) - """.stripMargin) - assert(result.count() === 1) + withTempView("groupData") { + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) + spark.udf.register("whereFilter", (n: Int) => { n < 150 }) + spark.udf.register("timesHundred", (n: Long) => { n * 100 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT timesHundred(SUM(v)) as v100 + | FROM groupData + | WHERE whereFilter(v) + | GROUP BY groupFunction(v) + | HAVING havingFilter(v100) + """.stripMargin) + assert(result.count() === 1) + } } test("struct UDF") { @@ -296,4 +312,16 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(explainStr(spark.range(1).select(udf1(udf2(functions.lit(1))))) .contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) } + + test("SPARK-23666 Do not display exprId in argument names") { + withTempView("x") { + Seq(((1, 2), 3)).toDF("a", "b").createOrReplaceTempView("x") + spark.udf.register("f", (a: Int) => a) + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + spark.sql("SELECT f(a._1) FROM x").show + } + assert(outputStream.toString.contains("UDF:f(a._1 AS `_1`)")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a08433ba794d..cc8b600efa46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -21,7 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty} import 
org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ @@ -44,6 +44,8 @@ object UDT { case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) case _ => false } + + override def toString: String = data.mkString("(", ", ", ")") } private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { @@ -143,7 +145,8 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] override def userClass: Class[IExampleSubType] = classOf[IExampleSubType] } -class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest + with ExpressionEvalHelper { import testImplicits._ private lazy val pointsRDD = Seq( @@ -304,4 +307,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT pointsRDD.except(pointsRDD2), Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) } + + test("SPARK-23054 Cast UserDefinedType to string") { + val udt = new UDT.MyDenseVectorUDT() + val vector = new UDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0)) + val data = udt.serialize(vector) + val ret = Cast(Literal(data, udt), StringType, None) + checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index aac8d56ba620..bde2de5b39fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import scala.util.Random + +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { @@ -101,4 +104,32 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } + + test("SPARK-23207: Make repartition() generate consistent output") { + def assertConsistency(ds: Dataset[java.lang.Long]): Unit = { + ds.persist() + + val exchange = ds.mapPartitions { iter => + Random.shuffle(iter) + }.repartition(111) + val exchange2 = ds.repartition(111) + + assert(exchange.rdd.collectPartitions() === exchange2.rdd.collectPartitions()) + } + + withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "true") { + // repartition() should generate consistent output. + assertConsistency(spark.range(10000)) + + // case when input contains duplicated rows. 
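      // (Illustrative aside, not part of the original patch.) Round-robin repartitioning assigns
      // rows to output partitions in arrival order, so a non-deterministic input order (such as the
      // Random.shuffle inside assertConsistency above) could let a retried task send the same row
      // to a different partition. With SQLConf.SORT_BEFORE_REPARTITION enabled, each task first
      // sorts its rows by binary content before the round-robin assignment, which removes the
      // dependence on input order. A minimal hedged sketch of toggling it directly, assuming the
      // key resolves to "spark.sql.execution.sortBeforeRepartition":
      //   spark.conf.set(SQLConf.SORT_BEFORE_REPARTITION.key, "true")   // order-independent assignment
      //   spark.conf.set(SQLConf.SORT_BEFORE_REPARTITION.key, "false")  // assignment follows input order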
+ assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong)) + } + } + + test("SPARK-23614: Fix incorrect reuse exchange when caching is used") { + val cached = spark.createDataset(Seq((1, 2, 3), (4, 5, 6))).cache() + val projection1 = cached.select("_1", "_2").queryExecution.executedPlan + val projection2 = cached.select("_1", "_3").queryExecution.executedPlan + assert(!projection1.sameResult(projection2)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index cc943e0356f2..972b47e96fe0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -36,7 +36,7 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("basic semantic") { val expectedErrorMsg = "not found" - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") // If there is no database in table name, we should try local temp view first, if not found, @@ -79,19 +79,15 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { // We can also use Dataset API to replace global temp view Seq(2 -> "b").toDF("i", "j").createOrReplaceGlobalTempView("src") checkAnswer(spark.table(s"$globalTempDB.src"), Row(2, "b")) - } finally { - spark.catalog.dropGlobalTempView("src") } } test("global temp view is shared among all sessions") { - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, 2)) val newSession = spark.newSession() checkAnswer(newSession.table(s"$globalTempDB.src"), Row(1, 2)) - } finally { - spark.catalog.dropGlobalTempView("src") } } @@ -105,27 +101,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("CREATE GLOBAL TEMP VIEW USING") { withTempPath { path => - try { + withGlobalTempView("src") { Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) sql(s"CREATE GLOBAL TEMP VIEW src USING parquet OPTIONS (PATH '${path.toURI}')") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) sql(s"INSERT INTO $globalTempDB.src SELECT 2, 'b'") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a") :: Row(2, "b") :: Nil) - } finally { - spark.catalog.dropGlobalTempView("src") } } } test("CREATE TABLE LIKE should work for global temp view") { - try { - sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") - sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") - val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) - assert(tableMeta.schema == new StructType().add("a", "int", false).add("b", "string", false)) - } finally { - spark.catalog.dropGlobalTempView("src") - sql("DROP TABLE default.cloned") + withTable("cloned") { + withGlobalTempView("src") { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") + sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") + val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) + assert(tableMeta.schema == new StructType() + .add("a", "int", false).add("b", "string", false)) + } } } @@ -140,32 +134,31 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { assert(spark.catalog.listTables(globalTempDB).collect().toSeq.map(_.name) == Seq("v1", "v2")) } finally { - spark.catalog.dropTempView("v1") - spark.catalog.dropGlobalTempView("v2") + 
spark.catalog.dropGlobalTempView("v1") + spark.catalog.dropTempView("v2") } } test("should lookup global temp view if and only if global temp db is specified") { - try { - sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") - sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") + withTempView("same_name") { + withGlobalTempView("same_name") { + sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") + sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") - checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) + checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) - // we never lookup global temp views if database is not specified in table name - spark.catalog.dropTempView("same_name") - intercept[AnalysisException](sql("SELECT * FROM same_name")) + // we never lookup global temp views if database is not specified in table name + spark.catalog.dropTempView("same_name") + intercept[AnalysisException](sql("SELECT * FROM same_name")) - // Use qualified name to lookup a global temp view. - checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) - } finally { - spark.catalog.dropTempView("same_name") - spark.catalog.dropGlobalTempView("same_name") + // Use qualified name to lookup a global temp view. + checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) + } } } test("public Catalog should recognize global temp view") { - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") assert(spark.catalog.tableExists(globalTempDB, "src")) @@ -175,8 +168,6 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { description = null, tableType = "TEMPORARY", isTemporary = true).toString) - } finally { - spark.catalog.dropGlobalTempView("src") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala index 78c1e5dae566..a543eb835165 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.execution +import java.io.File + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SharedSQLContext class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { @@ -125,4 +128,23 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT COUNT(DISTINCT p) FROM t_1000").collect() } } + + test("Incorrect result caused by the rule OptimizeMetadataOnlyQuery") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { + withTempPath { path => + val tablePath = new File(s"${path.getCanonicalPath}/cOl3=c/cOl1=a/cOl5=e") + Seq(("a", "b", "c", "d", "e")).toDF("cOl1", "cOl2", "cOl3", "cOl4", "cOl5") + .write.json(tablePath.getCanonicalPath) + + val df = spark.read.json(path.getCanonicalPath).select("CoL1", "CoL5", "CoL3").distinct() + checkAnswer(df, Row("a", "e", "c")) + + val localRelation = df.queryExecution.optimizedPlan.collectFirst { + case l: LocalRelation => l + } + assert(localRelation.nonEmpty, "expect to see a LocalRelation") + assert(localRelation.get.output.map(_.name) == Seq("cOl3", "cOl1", "cOl5")) + } + } + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b50642d275ba..f0dfe6b76f7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -194,7 +194,20 @@ class PlannerSuite extends SharedSQLContext { test("CollectLimit can appear in the middle of a plan when caching is used") { val query = testData.select('key, 'value).limit(2).cache() val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] - assert(planned.child.isInstanceOf[CollectLimitExec]) + assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) + } + + test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { + val query = testData.select('key, 'value).sort('key.desc).cache() + assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) + val resorted = query.sort('key.desc) + assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty) + assert(resorted.select('key).collect().map(_.getInt(0)).toSeq == + (1 to 100).reverse) + // with a different order, the sort is needed + val sortedAsc = query.sort('key) + assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1) + assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100)) } test("PartitioningCollection") { @@ -260,11 +273,16 @@ class PlannerSuite extends SharedSQLContext { // do they satisfy the distribution requirements? As a result, we need at least four test cases. private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = { - if (outputPlan.children.length > 1 - && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) { - val childPartitionings = outputPlan.children.map(_.outputPartitioning) - if (!Partitioning.allCompatible(childPartitionings)) { - fail(s"Partitionings are not compatible: $childPartitionings") + if (outputPlan.children.length > 1) { + val childPartitionings = outputPlan.children.zip(outputPlan.requiredChildDistribution) + .filter { + case (_, UnspecifiedDistribution) => false + case (_, _: BroadcastDistribution) => false + case _ => true + }.map(_._1.outputPartitioning) + + if (childPartitionings.map(_.numPartitions).toSet.size > 1) { + fail(s"Partitionings doesn't have same number of partitions: $childPartitionings") } } outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach { @@ -274,40 +292,7 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements with incompatible child partitionings which satisfy distribution") { - // Consider an operator that requires inputs that are clustered by two expressions (e.g. 
- // sort merge join where there are multiple columns in the equi-join condition) - val clusteringA = Literal(1) :: Nil - val clusteringB = Literal(2) :: Nil - val distribution = ClusteredDistribution(clusteringA ++ clusteringB) - // Say that the left and right inputs are each partitioned by _one_ of the two join columns: - val leftPartitioning = HashPartitioning(clusteringA, 1) - val rightPartitioning = HashPartitioning(clusteringB, 1) - // Individually, each input's partitioning satisfies the clustering distribution: - assert(leftPartitioning.satisfies(distribution)) - assert(rightPartitioning.satisfies(distribution)) - // However, these partitionings are not compatible with each other, so we still need to - // repartition both inputs prior to performing the join: - assert(!leftPartitioning.compatibleWith(rightPartitioning)) - assert(!rightPartitioning.compatibleWith(leftPartitioning)) - val inputPlan = DummySparkPlan( - children = Seq( - DummySparkPlan(outputPartitioning = leftPartitioning), - DummySparkPlan(outputPartitioning = rightPartitioning) - ), - requiredChildDistribution = Seq(distribution, distribution), - requiredChildOrdering = Seq(Seq.empty, Seq.empty) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { - fail(s"Exchange should have been added:\n$outputPlan") - } - } - test("EnsureRequirements with child partitionings with different numbers of output partitions") { - // This is similar to the previous test, except it checks that partitionings are not compatible - // unless they produce the same number of partitions. val clustering = Literal(1) :: Nil val distribution = ClusteredDistribution(clustering) val inputPlan = DummySparkPlan( @@ -386,18 +371,15 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { + test("EnsureRequirements eliminates Exchange if child has same partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) - val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) - val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) - assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchangeExec(finalPartitioning, - DummySparkPlan( - children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, - requiredChildDistribution = Seq(distribution), - requiredChildOrdering = Seq(Seq.empty)), - None) + val partitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(partitioning.satisfies(distribution)) + val inputPlan = ShuffleExchangeExec( + partitioning, + DummySparkPlan(outputPartitioning = partitioning), + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 2) { @@ -407,17 +389,13 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements does not eliminate Exchange with different partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) - // Number of partitions differ - val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) - val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) - assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchangeExec(finalPartitioning, - DummySparkPlan( - children = 
DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, - requiredChildDistribution = Seq(distribution), - requiredChildOrdering = Seq(Seq.empty)), - None) + val partitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!partitioning.satisfies(distribution)) + val inputPlan = ShuffleExchangeExec( + partitioning, + DummySparkPlan(outputPartitioning = partitioning), + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 08a4a21b20f6..8269d4d3a285 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -53,15 +53,17 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("create a temp view on a permanent view") { - withView("jtv1", "temp_jtv1") { - sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3") - sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jtv1 WHERE id < 6") - checkAnswer(sql("select count(*) FROM temp_jtv1"), Row(2)) + withView("jtv1") { + withTempView("temp_jtv1") { + sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3") + sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jtv1 WHERE id < 6") + checkAnswer(sql("select count(*) FROM temp_jtv1"), Row(2)) + } } } test("create a temp view on a temp view") { - withView("temp_jtv1", "temp_jtv2") { + withTempView("temp_jtv1", "temp_jtv2") { sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") sql("CREATE TEMPORARY VIEW temp_jtv2 AS SELECT * FROM temp_jtv1 WHERE id < 6") checkAnswer(sql("select count(*) FROM temp_jtv2"), Row(2)) @@ -69,21 +71,25 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("create a permanent view on a temp view") { - withView("jtv1", "temp_jtv1", "global_temp_jtv1") { - sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") - var e = intercept[AnalysisException] { - sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6") - }.getMessage - assert(e.contains("Not allowed to create a permanent view `jtv1` by " + - "referencing a temporary view `temp_jtv1`")) - - val globalTempDB = spark.sharedState.globalTempViewManager.database - sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0") - e = intercept[AnalysisException] { - sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6") - }.getMessage - assert(e.contains(s"Not allowed to create a permanent view `jtv1` by referencing " + - s"a temporary view `global_temp`.`global_temp_jtv1`")) + withView("jtv1") { + withTempView("temp_jtv1") { + withGlobalTempView("global_temp_jtv1") { + sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") + var e = intercept[AnalysisException] { + sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6") + }.getMessage + assert(e.contains("Not allowed to create a permanent view `jtv1` by " + + "referencing a temporary view `temp_jtv1`")) + + val globalTempDB = spark.sharedState.globalTempViewManager.database + sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0") + e = intercept[AnalysisException] { + sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6") 
+ }.getMessage + assert(e.contains(s"Not allowed to create a permanent view `jtv1` by referencing " + + s"a temporary view `global_temp`.`global_temp_jtv1`")) + } + } } } @@ -218,10 +224,12 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("error handling: disallow IF NOT EXISTS for CREATE TEMPORARY VIEW") { - val e = intercept[AnalysisException] { - sql("CREATE TEMPORARY VIEW IF NOT EXISTS myabcdview AS SELECT * FROM jt") + withTempView("myabcdview") { + val e = intercept[AnalysisException] { + sql("CREATE TEMPORARY VIEW IF NOT EXISTS myabcdview AS SELECT * FROM jt") + } + assert(e.message.contains("It is not allowed to define a TEMPORARY view with IF NOT EXISTS")) } - assert(e.message.contains("It is not allowed to define a TEMPORARY view with IF NOT EXISTS")) } test("error handling: fail if the temp view sql itself is invalid") { @@ -270,7 +278,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("correctly parse CREATE TEMPORARY VIEW statement") { - withView("testView") { + withTempView("testView") { sql( """CREATE TEMPORARY VIEW |testView (c1 COMMENT 'blabla', c2 COMMENT 'blabla') @@ -282,28 +290,32 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("should NOT allow CREATE TEMPORARY VIEW when TEMPORARY VIEW with same name exists") { - withView("testView") { + withTempView("testView") { sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") val e = intercept[AnalysisException] { sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") } - assert(e.message.contains("Temporary table") && e.message.contains("already exists")) + assert(e.message.contains("Temporary view") && e.message.contains("already exists")) } } test("should allow CREATE TEMPORARY VIEW when a permanent VIEW with same name exists") { withView("testView", "default.testView") { - sql("CREATE VIEW testView AS SELECT id FROM jt") - sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + withTempView("testView") { + sql("CREATE VIEW testView AS SELECT id FROM jt") + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + } } } test("should allow CREATE permanent VIEW when a TEMPORARY VIEW with same name exists") { withView("testView", "default.testView") { - sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") - sql("CREATE VIEW testView AS SELECT id FROM jt") + withTempView("testView") { + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + } } } @@ -659,7 +671,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }.getMessage assert(e.contains("The depth of view `default`.`view0` exceeds the maximum view " + "resolution depth (10). Analysis is aborted to avoid errors. 
Increase the value " + - "of spark.sql.view.maxNestedViewDepth to work aroud this.")) + "of spark.sql.view.maxNestedViewDepth to work around this.")) } val e = intercept[IllegalArgumentException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 232c1beae799..3e31d22e15c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = Random.nextInt(10000), attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 604502f2a57d..bf588d3bb784 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.map.BytesToBytesMap /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. @@ -116,6 +117,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 98456, attemptNumber = 0, @@ -204,4 +206,42 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { spill = true ) } + + test("SPARK-23376: Create UnsafeKVExternalSorter with BytesToByteMap having duplicated keys") { + val memoryManager = new TestMemoryManager(new SparkConf()) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + val map = new BytesToBytesMap(taskMemoryManager, 64, taskMemoryManager.pageSizeBytes()) + + // Key/value are a unsafe rows with a single int column + val schema = new StructType().add("i", IntegerType) + val key = new UnsafeRow(1) + key.pointTo(new Array[Byte](32), 32) + key.setInt(0, 1) + val value = new UnsafeRow(1) + value.pointTo(new Array[Byte](32), 32) + value.setInt(0, 2) + + for (_ <- 1 to 65) { + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + loc.append( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + + // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap` + // which has duplicated keys and the number of entries exceeds its capacity. 
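      // (Illustrative aside, not part of the original patch.) The map above was created with an
      // initial capacity of 64 but received 65 appends under a single key: duplicated keys add
      // records without adding distinct keys, so the map never grows and its record count can
      // exceed what its internal pointer array was sized for. UnsafeKVExternalSorter reuses that
      // pointer array as its in-memory sort buffer when built from an existing map, and before
      // SPARK-23376 that reuse assumed the record count would fit, which is the assumption
      // exercised by the constructor call below.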
+ try { + TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null)) + new UnsafeKVExternalSorter( + schema, + schema, + sparkContext.env.blockManager, + sparkContext.env.serializerManager, + taskMemoryManager.pageSizeBytes(), + Int.MaxValue, + map) + } finally { + TaskContext.unset() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index dff88ce7f1b9..a3ae93810aa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { (i, converter(Row(i))) } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index bc05dca578c4..9180a22c260f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -24,13 +25,15 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed -import org.apache.spark.sql.functions.{avg, broadcast, col, max} +import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + test("range/filter should be combined") { val df = spark.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan @@ -121,31 +124,23 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { import testImplicits._ - val dsInt = spark.range(3).cache - dsInt.count + val dsInt = spark.range(3).cache() + dsInt.count() val dsIntFilter = dsInt.filter(_ > 0) val planInt = dsIntFilter.queryExecution.executedPlan - assert(planInt.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined - 
) + assert(planInt.collect { + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if i.supportsBatch => () + }.length == 1) assert(dsIntFilter.collect() === Array(1, 2)) // cache for string type is not supported for InMemoryTableScanExec - val dsString = spark.range(3).map(_.toString).cache - dsString.count + val dsString = spark.range(3).map(_.toString).cache() + dsString.count() val dsStringFilter = dsString.filter(_ == "1") val planString = dsStringFilter.queryExecution.executedPlan - assert(planString.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec]).isDefined - ) + assert(planString.collect { + case i: InMemoryTableScanExec if !i.supportsBatch => () + }.length == 1) assert(dsStringFilter.collect() === Array("1")) } @@ -209,16 +204,16 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 } - test("SPARK-21871 check if we can get large code size when compiling too long functions") { + ignore("SPARK-21871 check if we can get large code size when compiling too long functions") { val codeWithShortFunctions = genGroupByCode(3) val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) - val codeWithLongFunctions = genGroupByCode(20) + val codeWithLongFunctions = genGroupByCode(50) val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } - test("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { + ignore("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { import testImplicits._ withTempPath { dir => val path = dir.getCanonicalPath @@ -236,4 +231,92 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("Control splitting consume function by operators with config") { + import testImplicits._ + val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*) + + Seq(true, false).foreach { config => + withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") { + val plan = df.queryExecution.executedPlan + val wholeStageCodeGenExec = plan.find(p => p match { + case wp: WholeStageCodegenExec => true + case _ => false + }) + assert(wholeStageCodeGenExec.isDefined) + val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 + assert(code.body.contains("project_doConsume") == config) + } + } + } + + test("Skip splitting consume function when parameter number exceeds JVM limit") { + // since every field is nullable we have 2 params for each input column (one for the value + // and one for the isNull variable) + Seq((128, false), (127, true)).foreach { case (columnNum, hasSplit) => + withTempPath { dir => + val path = dir.getCanonicalPath + spark.range(10).select(Seq.tabulate(columnNum) {i => lit(i).as(s"c$i")} : _*) + .write.mode(SaveMode.Overwrite).parquet(path) + + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", + SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") { + val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i") + val df = spark.read.parquet(path).selectExpr(projection: _*) + + val plan = 
df.queryExecution.executedPlan + val wholeStageCodeGenExec = plan.find { + case _: WholeStageCodegenExec => true + case _ => false + } + assert(wholeStageCodeGenExec.isDefined) + val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 + assert(code.body.contains("project_doConsume") == hasSplit) + } + } + } + } + + test("codegen stage IDs should be preserved in transformations after CollapseCodegenStages") { + // test case adapted from DataFrameSuite to trigger ReuseExchange + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") { + val df = spark.range(100) + val join = df.join(df, "id") + val plan = join.queryExecution.executedPlan + assert(!plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined, + "codegen stage IDs should be preserved through ReuseExchange") + checkAnswer(join, df.toDF) + } + } + + test("including codegen stage ID in generated class name should not regress codegen caching") { + import testImplicits._ + + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") { + val bytecodeSizeHisto = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE + + // the same query run twice should hit the codegen cache + spark.range(3).select('id + 2).collect + val after1 = bytecodeSizeHisto.getCount + spark.range(3).select('id + 2).collect + val after2 = bytecodeSizeHisto.getCount // same query shape as above, deliberately + // bytecodeSizeHisto's count is always monotonically increasing if new compilation to + // bytecode had occurred. If the count stayed the same that means we've got a cache hit. + assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected") + + // a different query can result in codegen cache miss, that's by design + } + } + + ignore("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + var df = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") + for (i <- 0 until 70) { + df = df.groupBy("name").agg(avg("age").alias("age")) + } + assert(df.limit(1).collect() === Array(Row("bat", 8.0))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 10f1ee279bed..3fad7dfddadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte val conf = new SparkConf() sc = new SparkContext("local[2, 4]", "test", conf) val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0) - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null)) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) } override def afterAll(): Unit = TaskContext.unset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 508c116aae92..92506032ab2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.arrow import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.execution.vectorized.ArrowColumnVector import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { @@ -217,21 +217,21 @@ class ArrowWriterSuite extends SparkFunSuite { val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - val struct0 = reader.getStruct(0, 2) + val struct0 = reader.getStruct(0) assert(struct0.getInt(0) === 1) assert(struct0.getUTF8String(1) === UTF8String.fromString("str1")) - val struct1 = reader.getStruct(1, 2) + val struct1 = reader.getStruct(1) assert(struct1.isNullAt(0)) assert(struct1.isNullAt(1)) assert(reader.isNullAt(2)) - val struct3 = reader.getStruct(3, 2) + val struct3 = reader.getStruct(3) assert(struct3.getInt(0) === 4) assert(struct3.isNullAt(1)) - val struct4 = reader.getStruct(4, 2) + val struct4 = reader.getStruct(4) assert(struct4.isNullAt(0)) assert(struct4.getUTF8String(1) === UTF8String.fromString("str5")) @@ -252,15 +252,15 @@ class ArrowWriterSuite extends SparkFunSuite { val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - val struct00 = reader.getStruct(0, 1).getStruct(0, 2) + val struct00 = reader.getStruct(0).getStruct(0, 2) assert(struct00.getInt(0) === 1) assert(struct00.getUTF8String(1) === UTF8String.fromString("str1")) - val struct10 = reader.getStruct(1, 1).getStruct(0, 2) + val struct10 = reader.getStruct(1).getStruct(0, 2) assert(struct10.isNullAt(0)) assert(struct10.isNullAt(1)) - val struct2 = reader.getStruct(2, 1) + val struct2 = reader.getStruct(2) assert(struct2.isNullAt(0)) assert(reader.isNullAt(3)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 50ae26a3ff9d..470b93efd197 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.{Arrays, Comparator} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.Sorter import org.apache.spark.util.collection.unsafe.sort._ @@ -36,7 +36,7 @@ import org.apache.spark.util.random.XORShiftRandom class SortBenchmark extends BenchmarkBase { private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { override def compare( @@ -50,8 +50,8 @@ class SortBenchmark extends BenchmarkBase { private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](size * 2) { i => rand } val extended = ref ++ Array.fill[Long](size * 2)(0) - 
(new LongArray(MemoryBlock.fromLongArray(ref)), - new LongArray(MemoryBlock.fromLongArray(extended))) + (new LongArray(OnHeapMemoryBlock.fromArray(ref)), + new LongArray(OnHeapMemoryBlock.fromArray(extended))) } ignore("sort") { @@ -60,7 +60,7 @@ class SortBenchmark extends BenchmarkBase { val benchmark = new Benchmark("radix sort " + size, size) benchmark.addTimerCase("reference TimSort key prefix array") { timer => val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY) timer.stopTiming() @@ -78,7 +78,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xff i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -90,7 +90,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xffff i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -102,7 +102,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 69247d7f4e9a..abe61a2c2b9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -58,10 +58,13 @@ object TPCDSQueryBenchmark extends Logging { }.toMap } - def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = { - val tableSizes = setupTables(dataLocation) + def runTpcdsQueries( + queryLocation: String, + queries: Seq[String], + tableSizes: Map[String, Long], + nameSuffix: String = ""): Unit = { queries.foreach { name => - val queryString = resourceToString(s"tpcds/$name.sql", + val queryString = resourceToString(s"$queryLocation/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) // This is an indirect hack to estimate the size of each query's input by traversing the @@ -78,7 +81,7 @@ object TPCDSQueryBenchmark extends Logging { } val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum val benchmark = new Benchmark(s"TPCDS Snappy", numRows, 5) - benchmark.addCase(name) { i => + benchmark.addCase(s"$name$nameSuffix") { _ => spark.sql(queryString).collect() } logInfo(s"\n\n===== TPCDS QUERY BENCHMARK OUTPUT FOR $name =====\n") @@ -87,10 +90,20 @@ object TPCDSQueryBenchmark extends Logging { } } + def filterQueries( + origQueries: Seq[String], + args: TPCDSQueryBenchmarkArguments): Seq[String] = { + if (args.queryFilter.nonEmpty) { + origQueries.filter(args.queryFilter.contains) + } else { + origQueries + } + } + def main(args: Array[String]): Unit = { val benchmarkArgs = new TPCDSQueryBenchmarkArguments(args) - // List of all TPC-DS queries + // List of all TPC-DS v1.4 queries val tpcdsQueries = 
Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -103,20 +116,25 @@ object TPCDSQueryBenchmark extends Logging { "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + // This list only includes TPC-DS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + // If `--query-filter` defined, filters the queries that this option selects - val queriesToRun = if (benchmarkArgs.queryFilter.nonEmpty) { - val queries = tpcdsQueries.filter { case queryName => - benchmarkArgs.queryFilter.contains(queryName) - } - if (queries.isEmpty) { - throw new RuntimeException( - s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") - } - queries - } else { - tpcdsQueries + val queriesV1_4ToRun = filterQueries(tpcdsQueries, benchmarkArgs) + val queriesV2_7ToRun = filterQueries(tpcdsQueriesV2_7, benchmarkArgs) + + if ((queriesV1_4ToRun ++ queriesV2_7ToRun).isEmpty) { + throw new RuntimeException( + s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") } - tpcdsAll(benchmarkArgs.dataLocation, queries = queriesToRun) + val tableSizes = setupTables(benchmarkArgs.dataLocation) + runTpcdsQueries(queryLocation = "tpcds", queries = queriesV1_4ToRun, tableSizes) + runTpcdsQueries(queryLocation = "tpcds-v2.7.0", queries = queriesV2_7ToRun, tableSizes, + nameSuffix = "-v2.7") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index ff7c5e58e986..863703b15f4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec} import org.apache.spark.sql.functions._ @@ -42,10 +43,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val storageLevel = MEMORY_ONLY val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, - data.logicalPlan.stats) + data.logicalPlan) - assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) - inMemoryRelation.cachedColumnBuffers.collect().head match { + assert(inMemoryRelation.cacheBuilder.cachedColumnBuffers.getStorageLevel == storageLevel) + inMemoryRelation.cacheBuilder.cachedColumnBuffers.collect().head match { case _: CachedBatch => case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}") } @@ -119,7 +120,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("simple columnar query") { 
val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) } @@ -138,7 +139,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val logicalPlan = testData.select('value, 'key).logicalPlan val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - logicalPlan.stats) + logicalPlan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -155,7 +156,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) @@ -329,14 +330,14 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-17549: cached table size should be correctly calculated") { val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan - val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan.stats) + val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan) // Materialize the data. val expectedAnswer = data.collect() checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize) + assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) } test("access primitive-type columns in CachedBatch without whole stage codegen") { @@ -455,7 +456,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") { val attribute = AttributeReference("a", IntegerType)() val localTableScanExec = LocalTableScanExec(Seq(attribute), Nil) - val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, null) + val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, + LocalRelation(Seq(attribute), Nil)) val tableScanExec = InMemoryTableScanExec(Seq(attribute), Seq(In(attribute, Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) @@ -477,7 +479,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) val execPlan = if (enabled == "true") { - WholeStageCodegenExec(planBeforeFilter.head) + WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0) } else { planBeforeFilter.head } @@ -487,7 +489,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-22673: InMemoryRelation should utilize existing stats of the plan to be cached") { - withSQLConf("spark.sql.cbo.enabled" -> "true") { + // This test case depends on the size of parquet in statistics. 
+ withSQLConf( + SQLConf.CBO_ENABLED.key -> "true", + SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "parquet") { withTempPath { workDir => withTable("table1") { val workDirPath = workDir.getAbsolutePath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 9005ec93e786..619b76fabdd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -77,7 +77,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) + val benchmark = new Benchmark(name, iters * count.toLong) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) @@ -101,7 +101,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) + val benchmark = new Benchmark(name, iters * count.toLong) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, _, buf) = prepareEncodeInternal(count, tpe, scheme, input) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala new file mode 100644 index 000000000000..f3e15189a641 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.catalog.CatalogStatistics + +class CommandUtilsSuite extends SparkFunSuite { + + test("Check if compareAndGetNewStats returns correct results") { + val oldStats1 = CatalogStatistics(sizeInBytes = 10, rowCount = Some(100)) + val newStats1 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = 10, newRowCount = Some(100)) + assert(newStats1.isEmpty) + val newStats2 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = -1, newRowCount = None) + assert(newStats2.isEmpty) + val newStats3 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = 20, newRowCount = Some(-1)) + assert(newStats3.isDefined) + newStats3.foreach { stat => + assert(stat.sizeInBytes === 20) + assert(stat.rowCount.isEmpty) + } + val newStats4 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = -1, newRowCount = Some(200)) + assert(newStats4.isDefined) + newStats4.foreach { stat => + assert(stat.sizeInBytes === 10) + assert(stat.rowCount.isDefined && stat.rowCount.get === 200) + } + } + + test("Check if compareAndGetNewStats can handle large values") { + // Tests for large values + val oldStats2 = CatalogStatistics(sizeInBytes = BigInt(Long.MaxValue) * 2) + val newStats5 = CommandUtils.compareAndGetNewStats( + Some(oldStats2), newTotalSize = BigInt(Long.MaxValue) * 2, None) + assert(newStats5.isEmpty) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index eb7c33590b60..e0ccae15f1d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -54,6 +54,13 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parser.parsePlan(sqlCommand)).getMessage + messages.foreach { message => + assert(e.contains(message)) + } + } + private def parseAs[T: ClassTag](query: String): T = { parser.parsePlan(query) match { case t: T => t @@ -229,7 +236,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), - isTemp = true, ifNotExists = false, replace = false) + isTemp = true, ignoreIfExists = false, replace = false) val expected2 = CreateFunctionCommand( Some("hello"), "world", @@ -237,7 +244,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = false, replace = false) + isTemp = false, ignoreIfExists = false, replace = false) val expected3 = CreateFunctionCommand( None, "helloworld3", @@ -245,7 +252,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), - isTemp = true, ifNotExists = false, replace = true) + isTemp = true, ignoreIfExists = false, replace = true) val expected4 = CreateFunctionCommand( Some("hello"), 
"world1", @@ -253,7 +260,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = false, replace = true) + isTemp = false, ignoreIfExists = false, replace = true) val expected5 = CreateFunctionCommand( Some("hello"), "world2", @@ -261,7 +268,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = true, replace = false) + isTemp = false, ignoreIfExists = true, replace = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected3) @@ -494,6 +501,37 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Duplicate clauses - create table") { + def createTableHeader(duplicateClause: String, isNative: Boolean): String = { + val fileFormat = if (isNative) "USING parquet" else "STORED AS parquet" + s"CREATE TABLE my_tab(a INT, b STRING) $fileFormat $duplicateClause $duplicateClause" + } + + Seq(true, false).foreach { isNative => + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')", isNative), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'", isNative), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'", isNative), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS", isNative), + "Found duplicate clauses: CLUSTERED BY") + } + + // Only for native data source tables + intercept(createTableHeader("PARTITIONED BY (b)", isNative = true), + "Found duplicate clauses: PARTITIONED BY") + + // Only for Hive serde tables + intercept(createTableHeader("PARTITIONED BY (k int)", isNative = false), + "Found duplicate clauses: PARTITIONED BY") + intercept(createTableHeader("STORED AS parquet", isNative = false), + "Found duplicate clauses: STORED AS/BY") + intercept( + createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'", isNative = false), + "Found duplicate clauses: ROW FORMAT") + } + test("create table - with location") { val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" @@ -1153,38 +1191,119 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Test CTAS against data source tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database 
== Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.provider == Some("parquet")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } + test("Test CTAS #1") { val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |STORED AS RCFILE |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s1) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - // TODO will be SQLText - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map("p1" -> 
"v1", "p2" -> "v2")) + } } test("Test CTAS #2") { - val s2 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + val s1 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' | STORED AS @@ -1192,26 +1311,45 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s2) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - // TODO will be SQLText - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #3") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fdb9b2f51f9c..3998ceca38b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -37,6 +37,8 @@ 
import org.apache.spark.util.Utils class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with BeforeAndAfterEach { + import testImplicits._ + override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test @@ -132,6 +134,37 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo checkAnswer(spark.table("t"), Row(Row("a", 1)) :: Nil) } } + + // TODO: This test is copied from HiveDDLSuite, unify it later. + test("SPARK-23348: append data to data source table with saveAsTable") { + withTable("t", "t1") { + Seq(1 -> "a").toDF("i", "j").write.saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a")) + + sql("INSERT INTO t SELECT 2, 'b'") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Nil) + + Seq(3 -> "c").toDF("i", "j").write.mode("append").saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Nil) + + Seq("c" -> 3).toDF("i", "j").write.mode("append").saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") + :: Row(null, "3") :: Nil) + + Seq(4 -> "d").toDF("i", "j").write.saveAsTable("t1") + + val e = intercept[AnalysisException] { + val format = if (spark.sessionState.conf.defaultDataSourceName.equalsIgnoreCase("json")) { + "orc" + } else { + "json" + } + Seq(5 -> "e").toDF("i", "j").write.mode("append").format(format).saveAsTable("t1") + } + assert(e.message.contains("The format of the existing table default.t1 is ")) + assert(e.message.contains("It doesn't match the specified format")) + } + } } abstract class DDLSuite extends QueryTest with SQLTestUtils { @@ -147,6 +180,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { private val escapedIdentifier = "`(.+)`".r + private def dataSource: String = { + if (isUsingHiveMetastore) { + "HIVE" + } else { + "PARQUET" + } + } protected def normalizeCatalogTable(table: CatalogTable): CatalogTable = table private def normalizeSerdeProp(props: Map[String, String]): Map[String, String] = { @@ -332,6 +372,75 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("CTAS a managed table with the existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + tableLoc.mkdir() + withTable("tab1") { + sql(s"CREATE TABLE tab1 USING ${dataSource} AS SELECT 1, 'a'") + checkAnswer(spark.table("tab1"), Row(1, "a")) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + + test("create a managed table with the existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + tableLoc.mkdir() + withTable("tab1") { + sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") + sql("INSERT INTO tab1 VALUES (1, 'a')") + checkAnswer(spark.table("tab1"), Row(1, "a")) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + + test("create a managed table with the existing non-empty directory") { + withTable("tab1") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + // create an empty hidden file + tableLoc.mkdir() + val hiddenGarbageFile = new File(tableLoc.getCanonicalPath, ".garbage") + hiddenGarbageFile.createNewFile() + val exMsg = "Can not create the managed table('`tab1`'). The associated location" + val exMsgWithDefaultDB = + "Can not create the managed table('`default`.`tab1`'). 
The associated location" + var ex = intercept[AnalysisException] { + sql(s"CREATE TABLE tab1 USING ${dataSource} AS SELECT 1, 'a'") + }.getMessage + if (isUsingHiveMetastore) { + assert(ex.contains(exMsgWithDefaultDB)) + } else { + assert(ex.contains(exMsg)) + } + + ex = intercept[AnalysisException] { + sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") + }.getMessage + assert(ex.contains(exMsgWithDefaultDB)) + + // Always check location of managed table, with or without (IF NOT EXISTS) + withTable("tab2") { + sql(s"CREATE TABLE tab2 (col1 int, col2 string) USING ${dataSource}") + ex = intercept[AnalysisException] { + sql(s"CREATE TABLE IF NOT EXISTS tab1 LIKE tab2") + }.getMessage + assert(ex.contains(exMsgWithDefaultDB)) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + } + private def checkSchemaInCreatedDataSourceTable( path: File, userSpecifiedSchema: Option[String], @@ -508,6 +617,35 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create table - append to a non-partitioned table created with different paths") { + import testImplicits._ + withTempDir { dir1 => + withTempDir { dir2 => + withTable("path_test") { + Seq(1L -> "a").toDF("v1", "v2") + .write + .mode(SaveMode.Append) + .format("json") + .option("path", dir1.getCanonicalPath) + .saveAsTable("path_test") + + val ex = intercept[AnalysisException] { + Seq((3L, "c")).toDF("v1", "v2") + .write + .mode(SaveMode.Append) + .format("json") + .option("path", dir2.getCanonicalPath) + .saveAsTable("path_test") + }.getMessage + assert(ex.contains("The location of the existing table `default`.`path_test`")) + + checkAnswer( + spark.table("path_test"), Row(1L, "a") :: Nil) + } + } + } + } + test("Refresh table after changing the data source table partitioning") { import testImplicits._ @@ -739,7 +877,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // starts with 'jar:', and it is an illegal parameter for Path, so here we copy it // to a temp file by withResourceTempPath withResourceTempPath("test-data/cars.csv") { tmpFile => - withView("testview") { + withTempView("testview") { sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + s"OPTIONS (PATH '${tmpFile.toURI}')") @@ -835,6 +973,31 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename temporary view - destination table with database name,with:CREATE TEMPORARY view") { + withTempView("view1") { + sql( + """ + |CREATE TEMPORARY VIEW view1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE view1 RENAME TO default.tab2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY VIEW from '`view1`' to '`default`.`tab2`': " + + "cannot specify database name 'default' in the destination table")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == Seq(TableIdentifier("view1"))) + } + } + test("rename temporary view") { withTempView("tab1", "tab2") { spark.range(10).createOrReplaceTempView("tab1") @@ -883,6 +1046,42 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename temporary view - destination table already exists, with: CREATE TEMPORARY view") { + withTempView("view1", "view2") { + sql( + """ + |CREATE TEMPORARY VIEW view1 + |USING 
org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + sql( + """ + |CREATE TEMPORARY VIEW view2 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE view1 RENAME TO view2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY VIEW from '`view1`' to '`view2`': destination table already exists")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == + Seq(TableIdentifier("view1"), TableIdentifier("view2"))) + } + } + test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -934,7 +1133,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part2 = Map("a" -> "2", "b" -> "6") val root = new Path(catalog.getTableMetadata(tableIdent).location) - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file @@ -991,6 +1190,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("SHOW DATABASES LIKE '*db1A'"), Row("showdb1a") :: Nil) + checkAnswer( + sql("SHOW DATABASES '*db1A'"), + Row("showdb1a") :: Nil) + checkAnswer( sql("SHOW DATABASES LIKE 'showdb1A'"), Row("showdb1a") :: Nil) @@ -1475,6 +1678,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // Ensure that change column will preserve other metadata fields. sql("ALTER TABLE dbx.tab1 CHANGE COLUMN col1 col1 INT COMMENT 'this is col1'") assert(getMetadata("col1").getString("key") == "value") + assert(getMetadata("col1").getString("comment") == "this is col1") } test("drop build-in function") { @@ -1538,8 +1742,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("DESCRIBE FUNCTION 'concat'"), Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") :: Row("Function: concat") :: - Row("Usage: concat(str1, str2, ..., strN) - " + - "Returns the concatenation of str1, str2, ..., strN.") :: Nil + Row("Usage: concat(col1, col2, ..., colN) - " + + "Returns the concatenation of col1, col2, ..., colN.") :: Nil ) // extended mode checkAnswer( @@ -1724,12 +1928,22 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("block creating duplicate temp table") { - withView("t_temp") { + withTempView("t_temp") { sql("CREATE TEMPORARY VIEW t_temp AS SELECT 1, 2") val e = intercept[TempTableAlreadyExistsException] { sql("CREATE TEMPORARY TABLE t_temp (c3 int, c4 string) USING JSON") }.getMessage - assert(e.contains("Temporary table 't_temp' already exists")) + assert(e.contains("Temporary view 't_temp' already exists")) + } + } + + test("block creating duplicate temp view") { + withTempView("t_temp") { + sql("CREATE TEMPORARY VIEW t_temp AS SELECT 1, 2") + val e = intercept[TempTableAlreadyExistsException] { + sql("CREATE TEMPORARY VIEW t_temp (c3 int, c4 string) USING JSON") + }.getMessage + assert(e.contains("Temporary view 't_temp' already exists")) } } @@ -1971,8 +2185,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t(a int, b int, c int, d int) |USING parquet - |PARTITIONED BY(a, b) |LOCATION "${dir.toURI}" + |PARTITIONED BY(a, b) """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, 
b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index b4616826e40b..18bb4bfe661c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -59,7 +59,7 @@ class FileIndexSuite extends SharedSQLContext { require(!unqualifiedDirPath.toString.contains("file:")) require(!unqualifiedFilePath.toString.contains("file:")) - val fs = unqualifiedDirPath.getFileSystem(sparkContext.hadoopConfiguration) + val fs = unqualifiedDirPath.getFileSystem(spark.sessionState.newHadoopConf()) val qualifiedFilePath = fs.makeQualified(new Path(file.getCanonicalPath)) require(qualifiedFilePath.toString.startsWith("file:")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c1d61b843d89..8764f0c42cf9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -401,7 +401,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi sparkSession = spark, rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], - partitionSchema = None) + userSpecifiedSchema = None) // This should not fail. fileCatalog.listLeafFiles(Seq(new Path(tempDir))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index caf03885e387..c1f2c18d1417 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources import java.io.{File, FilenameFilter} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.test.SharedSQLContext class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { @@ -39,4 +40,44 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize)) } } + + test("SPARK-22790: spark.sql.sources.compressionFactor takes effect") { + import testImplicits._ + Seq(1.0, 0.5).foreach { compressionFactor => + withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString, + "spark.sql.autoBroadcastJoinThreshold" -> "400") { + withTempPath { workDir => + // the file size is 740 bytes + val workDirPath = workDir.getAbsolutePath + val data1 = Seq(100, 200, 300, 400).toDF("count") + data1.write.parquet(workDirPath + "/data1") + val df1FromFile = spark.read.parquet(workDirPath + "/data1") + val data2 = Seq(100, 200, 300, 400).toDF("count") + data2.write.parquet(workDirPath + "/data2") + val df2FromFile = spark.read.parquet(workDirPath + "/data2") + val joinedDF = df1FromFile.join(df2FromFile, Seq("count")) + if (compressionFactor == 0.5) { + val bJoinExec = 
joinedDF.queryExecution.executedPlan.collect { + case bJoin: BroadcastHashJoinExec => bJoin + } + assert(bJoinExec.nonEmpty) + val smJoinExec = joinedDF.queryExecution.executedPlan.collect { + case smJoin: SortMergeJoinExec => smJoin + } + assert(smJoinExec.isEmpty) + } else { + // compressionFactor is 1.0 + val bJoinExec = joinedDF.queryExecution.executedPlan.collect { + case bJoin: BroadcastHashJoinExec => bJoin + } + assert(bJoinExec.isEmpty) + val smJoinExec = joinedDF.queryExecution.executedPlan.collect { + case smJoin: SortMergeJoinExec => smJoin + } + assert(smJoinExec.nonEmpty) + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4398e547d921..461abdd96d3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -30,12 +30,11 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { +class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with TestCsvData { import testImplicits._ private val carsFile = "test-data/cars.csv" @@ -1279,4 +1278,48 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil ) } + + test("SPARK-23846: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. 
+ withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(path.getCanonicalPath) + assert(readback.schema == new StructType().add("_c0", IntegerType)) + }) + } + + test("SPARK-23846: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(ds) + + assert(readback.schema == new StructType().add("_c0", IntegerType)) + } + + test("SPARK-23846: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", -1).csv(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", 0).csv(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) + + val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds) + assert(sampled.count() == ds.count()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala new file mode 100644 index 000000000000..3e20cc47dca2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} + +private[csv] trait TestCsvData { + protected def spark: SparkSession + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + index.toString + } else { + (index.toDouble + 0.1).toString + } + }(Encoders.STRING) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala new file mode 100644 index 000000000000..85cf054e51f6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.json + +import java.io.File + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.{LongType, StringType, StructType} +import org.apache.spark.util.{Benchmark, Utils} + +/** + * The benchmarks aims to measure performance of JSON parsing when encoding is set and isn't. 
+ * To run this: + * spark-submit --class --jars + */ +object JSONBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-json-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + + def schemaInferring(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON schema inferring", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + + benchmark.addCase("No encoding", 3) { _ => + spark.read.json(path.getAbsolutePath) + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 38902 / 39282 2.6 389.0 1.0X + UTF-8 is set 56959 / 57261 1.8 569.6 0.7X + */ + benchmark.run() + } + } + + def perlineParsing(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON per-line parsing", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write.json(path.getAbsolutePath) + val schema = new StructType().add("fieldA", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 25947 / 26188 3.9 259.5 1.0X + UTF-8 is set 46319 / 46417 2.2 463.2 0.6X + */ + benchmark.run() + } + } + + def perlineParsingOfWideColumn(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON parsing of wide lines", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map { i => + val s = "abcdef0123456789ABCDEF" * 20 + s"""{"a":"$s","b": $i,"c":"$s","d":$i,"e":"$s","f":$i,"x":"$s","y":$i,"z":"$s"}""" + } + .toDF().write.text(path.getAbsolutePath) + val schema = new StructType() + .add("a", StringType).add("b", LongType) + .add("c", StringType).add("d", LongType) + .add("e", StringType).add("f", LongType) + .add("x", StringType).add("y", LongType) + .add("z", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON parsing of wide lines: Best/Avg Time(ms) 
Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 45543 / 45660 0.2 4554.3 1.0X + UTF-8 is set 65737 / 65957 0.2 6573.7 0.7X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + schemaInferring(100 * 1000 * 1000) + perlineParsing(100 * 1000 * 1000) + perlineParsingOfWideColumn(10 * 1000 * 1000) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 8c8d41ebf115..4b3921c61a00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.{File, StringWriter} -import java.nio.charset.StandardCharsets +import java.io.{File, FileOutputStream, StringWriter} +import java.nio.charset.{StandardCharsets, UnsupportedCharsetException} +import java.nio.file.{Files, Paths, StandardOpenOption} import java.sql.{Date, Timestamp} import java.util.Locale @@ -27,7 +28,7 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} @@ -47,6 +48,10 @@ class TestFileFilter extends PathFilter { class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ + def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } + test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { assert(expected.getClass == actual.getClass, @@ -2063,4 +2068,363 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") { + // Read + val data = + s""" + | {"f": + |"a", "f0": 1}$lineSep{"f": + | + |"c", "f0": 2}$lineSep{"f": "d", "f0": 3} + """.stripMargin + val dataWithTrailingLineSep = s"$data$lineSep" + + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } + + // Write + withTempPath { path => + Seq("a", "b", "c").toDF("value").coalesce(1) + .write.option("lineSep", lineSep).json(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert( + readBack === s"""{"value":"a"}$lineSep{"value":"b"}$lineSep{"value":"c"}$lineSep""") + } + + // Roundtrip + withTempPath { path => + val df = Seq("a", "b", "c").toDF() + df.write.option("lineSep", lineSep).json(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", 
lineSep).json(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + // scalastyle:off nonascii + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, "아").foreach { lineSep => + testLineSeparator(lineSep) + } + // scalastyle:on nonascii + + test("""SPARK-21289: Support line separator - default value \r, \r\n and \n""") { + val data = + "{\"f\": \"a\", \"f0\": 1}\r{\"f\": \"c\", \"f0\": 2}\r\n{\"f\": \"d\", \"f0\": 3}\n" + + withTempPath { path => + Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.json(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } + + test("SPARK-23849: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + val readback = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) + + assert(readback.schema == new StructType().add("f1", LongType)) + }) + } + + test("SPARK-23849: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read.option("samplingRatio", 0.1).json(ds) + + assert(readback.schema == new StructType().add("f1", LongType)) + } + + test("SPARK-23849: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", -1).json(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", 0).json(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) + + val sampled = spark.read.option("samplingRatio", 1.0).json(ds) + assert(sampled.count() == ds.count()) + } + + test("SPARK-23723: json in UTF-16 with BOM") { + val fileName = "test-data/utf16WithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .option("encoding", "UTF-16") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"), Row("Doug", "Rood"))) + } + + test("SPARK-23723: multi-line json in UTF-32BE with BOM") { + val fileName = "test-data/utf32BEWithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Use user's encoding in reading of multi-line json in UTF-16LE") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16LE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + 
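+  // The two encoding tests below exercise charset handling: an unknown name such as "UTF-128"
+  // cannot be resolved by java.nio.charset.Charset, so reading is expected to fail with
+  // UnsupportedCharsetException, while valid charset names are matched case-insensitively.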
test("SPARK-23723: Unsupported encoding name") { + val invalidCharset = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + spark.read + .options(Map("encoding" -> invalidCharset, "lineSep" -> "\n")) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains(invalidCharset)) + } + + test("SPARK-23723: checking that the encoding option is case agnostic") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "uTf-16lE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + + test("SPARK-23723: specified encoding is not matched to actual encoding") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val exception = intercept[SparkException] { + spark.read.schema(schema) + .option("mode", "FAILFAST") + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16BE")) + .json(testFile(fileName)) + .count() + } + val errMsg = exception.getMessage + + assert(errMsg.contains("Malformed records are detected in record parsing")) + } + + def checkEncoding(expectedEncoding: String, pathToJsonFiles: String, + expectedContent: String): Unit = { + val jsonFiles = new File(pathToJsonFiles) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("json")) + val actualContent = jsonFiles.map { file => + new String(Files.readAllBytes(file.toPath), expectedEncoding) + }.mkString.trim + + assert(actualContent == expectedContent) + } + + test("SPARK-23723: save json in UTF-32BE") { + val encoding = "UTF-32BE" + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write + .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = encoding, + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: save json in default encoding - UTF-8") { + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write.json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = "UTF-8", + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: wrong output encoding") { + val encoding = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + withTempPath { path => + val df = spark.createDataset(Seq((0))) + df.write + .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .json(path.getCanonicalPath) + } + } + + assert(exception.getMessage == encoding) + } + + test("SPARK-23723: read back json in UTF-16LE") { + val options = Map("encoding" -> "UTF-16LE", "lineSep" -> "\n") + withTempPath { path => + val ds = spark.createDataset(Seq(("a", 1), ("b", 2), ("c", 3))).repartition(2) + ds.write.options(options).json(path.getCanonicalPath) + + val readBack = spark + .read + .options(options) + .json(path.getCanonicalPath) + + checkAnswer(readBack.toDF(), ds.toDF()) + } + } + + test("SPARK-23723: write json in UTF-16/32 with multiline off") { + Seq("UTF-16", "UTF-32").foreach { encoding => + withTempPath { path => + val ds = spark.createDataset(Seq( + ("a", 1), ("b", 2), ("c", 3)) + ).repartition(2) + val e = intercept[IllegalArgumentException] { + ds.write + .option("encoding", encoding) + .option("multiline", 
"false") + .format("json").mode("overwrite") + .save(path.getCanonicalPath) + }.getMessage + assert(e.contains( + s"$encoding encoding in the blacklist is not allowed when multiLine is disabled")) + } + } + } + + def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { + test(s"SPARK-23724: checks reading json in ${encoding} #${id}") { + val schema = new StructType().add("f1", StringType).add("f2", IntegerType) + withTempPath { path => + val records = List(("a", 1), ("b", 2)) + val data = records + .map(rec => s"""{"f1":"${rec._1}", "f2":${rec._2}}""".getBytes(encoding)) + .reduce((a1, a2) => a1 ++ lineSep.getBytes(encoding) ++ a2) + val os = new FileOutputStream(path) + os.write(data) + os.close() + val reader = if (inferSchema) { + spark.read + } else { + spark.read.schema(schema) + } + val readBack = reader + .option("encoding", encoding) + .option("lineSep", lineSep) + .json(path.getCanonicalPath) + checkAnswer(readBack, records.map(rec => Row(rec._1, rec._2))) + } + } + } + + // scalastyle:off nonascii + List( + (0, "|", "UTF-8", false), + (1, "^", "UTF-16BE", true), + (2, "::", "ISO-8859-1", true), + (3, "!!!@3", "UTF-32LE", false), + (4, 0x1E.toChar.toString, "UTF-8", true), + (5, "아", "UTF-32BE", false), + (6, "куку", "CP1251", true), + (7, "sep", "utf-8", false), + (8, "\r\n", "UTF-16LE", false), + (9, "\r\n", "utf-16be", true), + (10, "\u000d\u000a", "UTF-32BE", false), + (11, "\u000a\u000d", "UTF-8", true), + (12, "===", "US-ASCII", false), + (13, "$^+", "utf-32le", true) + ).foreach { + case (testNum, sep, encoding, inferSchema) => checkReadJson(sep, encoding, inferSchema, testNum) + } + // scalastyle:on nonascii + + test("SPARK-23724: lineSep should be set if encoding if different from UTF-8") { + val encoding = "UTF-16LE" + val exception = intercept[IllegalArgumentException] { + spark.read + .options(Map("encoding" -> encoding)) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains( + s"""The lineSep option must be specified for the $encoding encoding""")) + } + + private val badJson = "\u0000\u0000\u0000A\u0001AAA" + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is enabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson + """{"a":1}""").toDS().write.text(path) + val expected = s"""${badJson}{"a":1}\n""" + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", true) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Row(null, expected)) + } + } + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is disabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.text(path) + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", false) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) + } + } + + test("SPARK-23094: permissively parse a dataset contains JSON with leading nulls") { + checkAnswer( + spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), + Row(badJson)) + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 13084ba4a7f0..6e9559edf8ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -233,4 +233,16 @@ private[json] trait TestJsonData { spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING) def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING) + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + s"""{"f1":${index.toString}}""" + } else { + s"""{"f1":${(index.toDouble + 0.1).toString}}""" + } + }(Encoders.STRING) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index e00e057a18cc..f58c331f33ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -31,6 +31,7 @@ import org.apache.orc.OrcConf.COMPRESS import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcInputFormat +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} @@ -531,6 +532,52 @@ abstract class OrcQueryTest extends OrcTest { val df = spark.read.orc(path1.getCanonicalPath, path2.getCanonicalPath) assert(df.count() == 20) } + + test("Enabling/disabling ignoreCorruptFiles") { + def testIgnoreCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.orc(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.orc(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + def testIgnoreCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.orc(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.orc(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.schema("a long").orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { + testIgnoreCorruptFiles() + testIgnoreCorruptFilesWithoutSchemaInfer() + } + + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val m1 = intercept[SparkException] { + testIgnoreCorruptFiles() + }.getMessage + assert(m1.contains("Could not read footer for file")) + val m2 = intercept[SparkException] { + testIgnoreCorruptFilesWithoutSchemaInfer() + }.getMessage + 
assert(m2.contains("Malformed ORC file")) + } + } } class OrcQuerySuite extends OrcQueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 6f5f2fd795f7..8a3bbd03a26d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -39,8 +39,8 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { protected override def beforeAll(): Unit = { super.beforeAll() - orcTableAsDir = Utils.createTempDir("orctests", "sparksql") - orcTableDir = Utils.createTempDir("orctests", "sparksql") + orcTableAsDir = Utils.createTempDir(namePrefix = "orctests") + orcTableDir = Utils.createTempDir(namePrefix = "orctests") sparkContext .makeRDD(1 to 10) @@ -160,6 +160,15 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } } + + test("SPARK-23340 Empty float/double array columns raise EOFException") { + Seq(Seq(Array.empty[Float]).toDF(), Seq(Array.empty[Double]).toDF()).foreach { df => + withTempPath { path => + df.write.format("orc").save(path.getCanonicalPath) + checkAnswer(spark.read.orc(path.getCanonicalPath), df) + } + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala index caa4f6d70c6a..f3ecc5ced689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -101,7 +101,7 @@ class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils if (check) { result = Some(MarkingFileOutput.checkMarker( destPath, - spark.sparkContext.hadoopConfiguration)) + spark.sessionState.newHadoopConf())) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala new file mode 100644 index 000000000000..ed8fd2b45345 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetOutputFormat + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSQLContext { + test("Test `spark.sql.parquet.compression.codec` config") { + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO").foreach { c => + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { + val expected = if (c == "NONE") "UNCOMPRESSED" else c + val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) + assert(option.compressionCodecClassName == expected) + } + } + } + + test("[SPARK-21786] Test Acquiring 'compressionCodecClassName' for parquet in right order.") { + // When "compression" is configured, it should be the first choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map("compression" -> "uncompressed", ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "UNCOMPRESSED") + } + + // When "compression" is not configured, "parquet.compression" should be the preferred choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map(ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "GZIP") + } + + // When both "compression" and "parquet.compression" are not configured, + // spark.sql.parquet.compression.codec should be the right choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map.empty[String, String] + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "SNAPPY") + } + } + + private def getTableCompressionCodec(path: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + codecs.distinct + } + + private def createTableWithCompression( + tableName: String, + isPartitioned: Boolean, + compressionCodec: String, + rootDir: File): Unit = { + val options = + s""" + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName', + |'parquet.compression'='$compressionCodec') + """.stripMargin + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p)" else "" + sql( + s""" + |CREATE TABLE $tableName USING Parquet $options $partitionCreate + |AS SELECT 1 AS col1, 2 AS p + """.stripMargin) + } + + private def checkCompressionCodec(compressionCodec: String, isPartitioned: Boolean): Unit = { + withTempDir { tmpDir => + val tempTableName = "TempParquetTable" + withTable(tempTableName) { + createTableWithCompression(tempTableName, isPartitioned, compressionCodec, tmpDir) + val partitionPath = if (isPartitioned) "p=2" else "" + val path = s"${tmpDir.getPath.stripSuffix("/")}/$tempTableName/$partitionPath" + val realCompressionCodecs = getTableCompressionCodec(path) + assert(realCompressionCodecs.forall(_ == compressionCodec)) + } + } + } + + test("Create parquet table with compression") { + Seq(true, false).foreach { isPartitioned => + Seq("UNCOMPRESSED", "SNAPPY", 
"GZIP").foreach { compressionCodec => + checkCompressionCodec(compressionCodec, isPartitioned) + } + } + } + + test("Create table with unknown compression") { + Seq(true, false).foreach { isPartitioned => + val exception = intercept[IllegalArgumentException] { + checkCompressionCodec("aa", isPartitioned) + } + assert(exception.getMessage.contains("Codec [aa] is not available")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index edb1290ee2eb..db73bfa149aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -40,7 +40,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -65,7 +67,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -94,7 +98,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file, null /* set columns to null to project all columns */) val column = reader.resultBatch().column(0) assert(reader.nextBatch()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala index ccb34355f1ba..3a0867fd2b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala @@ -29,7 +29,7 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo test("read parquet footers in parallel") { def testReadFooters(ignoreCorruptFiles: Boolean): Unit = { withTempDir { dir => - val fs = FileSystem.get(sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val basePath = 
dir.getCanonicalPath val path1 = new Path(basePath, "first") @@ -44,7 +44,7 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo Seq(fs.listStatus(path1), fs.listStatus(path2), fs.listStatus(path3)).flatten val footers = ParquetFileFormat.readParquetFootersInParallel( - sparkContext.hadoopConfiguration, fileStatuses, ignoreCorruptFiles) + spark.sessionState.newHadoopConf(), fileStatuses, ignoreCorruptFiles) assert(footers.size == 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 33801954ebd5..667e0b1760e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.charset.StandardCharsets +import java.sql.Date import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ @@ -54,6 +55,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { + private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate) + override def beforeEach(): Unit = { super.beforeEach() // Note that there are many tests here that require record-level filtering set to be true. @@ -76,8 +79,10 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df .select(output.map(e => Column(e)): _*) .where(Column(predicate)) @@ -96,13 +101,12 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(selectedFilters.nonEmpty, "No filter is pushed down") selectedFilters.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + val maybeFilter = parquetFilters.createFilter(df.schema, pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. 
`Eq[Integer]`) maybeFilter.exists(_.getClass === filterClass) } checker(stripSparkFilter(query), expected) - } } } @@ -313,6 +317,48 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - date") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21") + + withParquetDataFrame(data.map(i => Tuple1(i.date))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i.date))) + + checkFilterPredicate('_1 === "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate('_1 <=> "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate('_1 =!= "2018-03-18".date, classOf[NotEq[_]], + Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(i.date))) + + checkFilterPredicate('_1 < "2018-03-19".date, classOf[Lt[_]], "2018-03-18".date) + checkFilterPredicate('_1 > "2018-03-20".date, classOf[Gt[_]], "2018-03-21".date) + checkFilterPredicate('_1 <= "2018-03-18".date, classOf[LtEq[_]], "2018-03-18".date) + checkFilterPredicate('_1 >= "2018-03-21".date, classOf[GtEq[_]], "2018-03-21".date) + + checkFilterPredicate( + Literal("2018-03-18".date) === '_1, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-18".date) <=> '_1, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-19".date) > '_1, classOf[Lt[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-20".date) < '_1, classOf[Gt[_]], "2018-03-21".date) + checkFilterPredicate( + Literal("2018-03-18".date) >= '_1, classOf[LtEq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-21".date) <= '_1, classOf[GtEq[_]], "2018-03-21".date) + + checkFilterPredicate(!('_1 < "2018-03-21".date), classOf[GtEq[_]], "2018-03-21".date) + checkFilterPredicate( + '_1 < "2018-03-19".date || '_1 > "2018-03-20".date, + classOf[Operators.Or], + Seq(Row("2018-03-18".date), Row("2018-03-21".date))) + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ @@ -473,7 +519,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex lt(intColumn("a"), 10: Integer), gt(doubleColumn("c"), 1.5: java.lang.Double))) ) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.And( sources.LessThan("a", 10), @@ -481,7 +527,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } assertResult(None) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.And( sources.LessThan("a", 10), @@ -489,7 +535,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } assertResult(None) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.Not( sources.And( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 44a8b25c61df..0b3e8ca060d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import 
org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -653,7 +654,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, null) val result = mutable.ArrayBuffer.empty[(Int, String)] @@ -662,7 +665,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val v = (row.getInt(0), row.getString(1)) result += v } - assert(data == result) + assert(data.toSet == result.toSet) } finally { reader.close() } @@ -670,7 +673,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project just one column { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, ("_2" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String)] @@ -678,7 +683,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val row = reader.getCurrentValue.asInstanceOf[InternalRow] result += row.getString(0) } - assert(data.map(_._2) == result) + assert(data.map(_._2).toSet == result.toSet) } finally { reader.close() } @@ -686,7 +691,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project columns in opposite order { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String, Int)] @@ -695,7 +702,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val v = (row.getString(0), row.getInt(1)) result += v } - assert(data.map { x => (x._2, x._1) } == result) + assert(data.map { x => (x._2, x._1) }.toSet == result.toSet) } finally { reader.close() } @@ -703,7 +710,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Empty projection { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, List[String]().asJava) var result = 0 @@ -742,8 +751,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { dataTypes.zip(constantValues).foreach { case (dt, v) => val schema = StructType(StructField("pcol", dt) :: Nil) - val vectorizedReader = - new 
VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val vectorizedReader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) val partitionValues = new GenericInternalRow(Array(v)) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) @@ -771,6 +781,24 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { assert(option.compressionCodecClassName == "UNCOMPRESSED") } } + + test("SPARK-23173 Writing a file with data converted from JSON with and incorrect user schema") { + withTempPath { file => + val jsonData = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + spark.range(1).select(from_json(lit(jsonData), jsonSchema) as "input") + .write.parquet(file.getAbsolutePath) + checkAnswer(spark.read.parquet(file.getAbsolutePath), Seq(Row(Row(1, null, "foo")))) + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index e3edafa9c84e..9c75965639d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -163,7 +163,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS // Just to be defensive in case anything ever changes in parquet, this test checks // the assumption on column stats, and also the end-to-end behavior. - val hadoopConf = sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() val fs = FileSystem.get(hadoopConf) val parts = fs.listStatus(new Path(tableDir.getAbsolutePath), new PathFilter { override def accept(path: Path): Boolean = !path.getName.startsWith("_") @@ -184,7 +184,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS // when the data is read back as mentioned above, b/c int96 is unsigned. This // assert makes sure this holds even if we change parquet versions (if eg. there // were ever statistics even on unsigned columns). 
- assert(columnStats.isEmpty) + assert(!columnStats.hasNonNullValue) } // These queries should return the entire dataset with the conversion applied, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index d4902641e335..e887c9734a8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -57,6 +57,16 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val timeZone = TimeZone.getDefault() val timeZoneId = timeZone.getID + protected override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "parquet") + } + + protected override def afterAll(): Unit = { + spark.conf.unset(SQLConf.DEFAULT_DATA_SOURCE_NAME.key) + super.afterAll() + } + test("column type inference") { def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = { assert(inferPartitionColumnValue(raw, true, timeZone) === literal) @@ -1120,4 +1130,18 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Row(3, BigDecimal("2" * 30)) :: Nil) } } + + test("SPARK-23436: invalid Dates should be inferred as String in partition inference") { + withTempPath { path => + val data = Seq(("1", "2018-01", "2018-01-01-04", "test")) + .toDF("id", "date_month", "date_hour", "data") + + data.write.partitionBy("date_month", "date_hour").parquet(path.getAbsolutePath) + val input = spark.read.parquet(path.getAbsolutePath).select("id", + "date_month", "date_hour", "data") + + assert(input.schema.sameType(input.schema)) + checkAnswer(input, data) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 4c8c9ef6e043..e1f094d0a7af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -320,54 +320,38 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext new Path(basePath, "first").toString, new Path(basePath, "second").toString, new Path(basePath, "third").toString) - checkAnswer( - df, - Seq(Row(0), Row(1))) + checkAnswer(df, Seq(Row(0), Row(1))) } } - withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { - testIgnoreCorruptFiles() - } - - withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { - val exception = intercept[SparkException] { - testIgnoreCorruptFiles() - } - assert(exception.getMessage().contains("is not a Parquet file")) - } - } - - testQuietly("Enabling/disabling ignoreMissingFiles") { - def testIgnoreMissingFiles(): Unit = { + def testIgnoreCorruptFilesWithoutSchemaInfer(): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) - val thirdPath = new Path(basePath, "third") - spark.range(2, 3).toDF("a").write.parquet(thirdPath.toString) - val df = spark.read.parquet( + spark.range(2, 
3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.schema("a long").parquet( new Path(basePath, "first").toString, new Path(basePath, "second").toString, new Path(basePath, "third").toString) - - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) - fs.delete(thirdPath, true) - checkAnswer( - df, - Seq(Row(0), Row(1))) + checkAnswer(df, Seq(Row(0), Row(1))) } } - withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "true") { - testIgnoreMissingFiles() + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { + testIgnoreCorruptFiles() + testIgnoreCorruptFilesWithoutSchemaInfer() } - withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "false") { + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { val exception = intercept[SparkException] { - testIgnoreMissingFiles() + testIgnoreCorruptFiles() + } + assert(exception.getMessage().contains("is not a Parquet file")) + val exception2 = intercept[SparkException] { + testIgnoreCorruptFilesWithoutSchemaInfer() } - assert(exception.getMessage().contains("does not exist")) + assert(exception2.getMessage().contains("is not a Parquet file")) } } @@ -835,7 +819,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val path = dir.getCanonicalPath spark.range(3).write.parquet(path) - val fs = FileSystem.get(sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val files = fs.listFiles(new Path(path), true) while (files.hasNext) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 86a3c71a3c4f..e43336d94736 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -76,6 +76,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize spark.range(values).createOrReplaceTempView("t1") spark.sql("select cast(id as INT) as id from t1") .write.parquet(dir.getCanonicalPath) @@ -96,7 +97,8 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -119,7 +121,8 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -262,6 +265,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = 
spark.sessionState.conf.parquetVectorizedReaderBatchSize spark.range(values).createOrReplaceTempView("t1") spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") @@ -279,7 +283,8 @@ object ParquetReadBenchmark { benchmark.addCase("PR Vectorized") { num => var sum = 0 files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) val batch = reader.resultBatch() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 2cd2a600f2b9..9d3dfae348be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -20,10 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.parquet.io.ParquetDecodingException import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -382,6 +385,58 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } } + // ======================================= + // Tests for parquet schema mismatch error + // ======================================= + def testSchemaMismatch(path: String, vectorizedReaderEnabled: Boolean): SparkException = { + import testImplicits._ + + var e: SparkException = null + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReaderEnabled.toString) { + // Create two parquet files with different schemas in the same folder + Seq(("bcd", 2)).toDF("a", "b").coalesce(1).write.mode("overwrite").parquet(s"$path/parquet") + Seq((1, "abc")).toDF("a", "b").coalesce(1).write.mode("append").parquet(s"$path/parquet") + + e = intercept[SparkException] { + spark.read.parquet(s"$path/parquet").collect() + } + } + e + } + + test("schema mismatch failure error message for parquet reader") { + withTempPath { dir => + val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = false) + val expectedMessage = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the corresponding " + + "files. 
Details:" + assert(e.getCause.isInstanceOf[QueryExecutionException]) + assert(e.getCause.getCause.isInstanceOf[ParquetDecodingException]) + assert(e.getCause.getMessage.startsWith(expectedMessage)) + } + } + + test("schema mismatch failure error message for parquet vectorized reader") { + withTempPath { dir => + val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = true) + assert(e.getCause.isInstanceOf[QueryExecutionException]) + assert(e.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException]) + + // Check if the physical type is reporting correctly + val errMsg = e.getCause.getMessage + assert(errMsg.startsWith("Parquet column cannot be converted in file")) + val file = errMsg.substring("Parquet column cannot be converted in file ".length, + errMsg.indexOf(". ")) + val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) + assert(col.length == 1) + if (col(0).dataType == StringType) { + assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY")) + } else { + assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32")) + } + } + } + // ======================================================= // Tests for converting Parquet LIST to Catalyst ArrayType // ======================================================= diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 33287044f279..0e7f3afa9c3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -18,10 +18,13 @@ package org.apache.spark.sql.execution.datasources.text import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.spark.TestUtils import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -172,6 +175,45 @@ class TextSuite extends QueryTest with SharedSQLContext { } } + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-23577: Support line separator - lineSep: '$lineSep'") { + // Read + val values = Seq("a", "b", "\nc") + val data = values.mkString(lineSep) + val dataWithTrailingLineSep = s"$data$lineSep" + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, Seq("a", "b", "\nc").toDF()) + } + } + + // Write + withTempPath { path => + values.toDF().coalesce(1) + .write.option("lineSep", lineSep).text(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert(readBack === s"a${lineSep}b${lineSep}\nc${lineSep}") + } + + // Roundtrip + withTempPath { path => + val df = values.toDF() + df.write.option("lineSep", lineSep).text(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + // scalastyle:off nonascii + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, 
"아").foreach { lineSep => + testLineSeparator(lineSep) + } + // scalastyle:on nonascii + private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala index 8bd736bee69d..fff0f82f9bc2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala @@ -95,7 +95,7 @@ class WholeTextFileSuite extends QueryTest with SharedSQLContext { df1.write.option("compression", "gzip").mode("overwrite").text(path) // On reading through wholetext mode, one file will be read as a single row, i.e. not // delimited by "next line" character. - val expected = Row(Range(0, 1000).mkString("", "\n", "\n")) + val expected = Row(df1.collect().map(_.getString(0)).mkString("", "\n", "\n")) Seq(10, 100, 1000).foreach { bytes => withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> bytes.toString) { val df2 = spark.read.option("wholetext", "true").format("text").load(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 6da46ea3480b..bcdee792f4c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} -import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { private def testBroadcastJoin[T: ClassTag]( joinType: String, forceBroadcast: Boolean = false): SparkPlan = { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") @@ -109,17 +110,58 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } + test("SPARK-23192: broadcast hint should be retained after using the cached data") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + try { + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + df2.cache() + val df3 = df1.join(broadcast(df2), Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } finally { + 
spark.catalog.clearCache() + } + } + } + + test("SPARK-23214: cached data should not carry extra hint info") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + try { + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + broadcast(df2).cache() + + val df3 = df1.join(df2, Seq("key"), "inner") + val numCachedPlan = df3.queryExecution.executedPlan.collect { + case i: InMemoryTableScanExec => i + }.size + // df2 should be cached. + assert(numCachedPlan === 1) + + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + // df2 should not be broadcasted. + assert(numBroadCastHashJoin === 0) + } finally { + spark.catalog.clearCache() + } + } + } + test("broadcast hint isn't propagated after a join") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key")) - val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value") + val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value") val df5 = df4.join(df3, Seq("key"), "inner") - val plan = - EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) + val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1) @@ -127,30 +169,30 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner") - val plan = - EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) + val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) } test("broadcast hint programming API") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value") val broadcasted = broadcast(df2) - val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value") - - val cases = Seq(broadcasted.limit(2), - broadcasted.filter("value < 10"), - broadcasted.sample(true, 0.5), - broadcasted.distinct(), - broadcasted.groupBy("value").agg(min($"key").as("key")), - // except and intersect are semi/anti-joins which won't return more data then - // their left argument, so the broadcast hint should be propagated here - broadcasted.except(df3), - broadcasted.intersect(df3)) + val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value") + + val cases = Seq( + broadcasted.limit(2), + broadcasted.filter("value < 10"), + broadcasted.sample(true, 0.5), + broadcasted.distinct(), + broadcasted.groupBy("value").agg(min($"key").as("key")), + // except and intersect are semi/anti-joins which won't return more data then + // their left argument, so the broadcast hint should be propagated here 
+ broadcasted.except(df3), + broadcasted.intersect(df3)) cases.foreach(assertBroadcastJoin) } @@ -227,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { test("Shouldn't change broadcast join buildSide if user clearly specified") { withTempView("t1", "t2") { - spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") - spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") - .createTempView("t2") + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes @@ -279,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { test("Shouldn't bias towards build right if user didn't specify") { withTempView("t1", "t2") { - spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") - spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") - .createTempView("t2") + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes @@ -318,7 +358,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { case b: BroadcastNestedLoopJoinExec => assert(b.getClass.getSimpleName === joinMethod) assert(b.buildSide === buildSide) - case b: BroadcastNestedLoopJoinExec => + case b: BroadcastHashJoinExec => assert(b.getClass.getSimpleName === joinMethod) assert(b.buildSide === buildSide) case w: WholeStageCodegenExec => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 122d28798136..dcc540fc4f10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -21,19 +21,20 @@ import java.io.File import scala.collection.mutable.HashMap +import org.apache.spark.TestUtils import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore} import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.util.Utils trait SQLMetricsTestUtils extends SQLTestUtils { import testImplicits._ protected def currentExecutionIds(): Set[Long] = { + spark.sparkContext.listenerBus.waitUntilEmpty(10000) statusStore.executionsList.map(_.executionId).toSet } @@ -91,7 +92,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2) .write.format(dataFormat).mode("overwrite").insertInto(tableName) } - assert(Utils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) + assert(TestUtils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) } } @@ -121,7 +122,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { .mode("overwrite") .insertInto(tableName) } - 
assert(Utils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) + assert(TestUtils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala index ffda33cf906c..25ee95daa034 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala @@ -22,13 +22,13 @@ import java.io.File import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.Utils class RowQueueSuite extends SparkFunSuite { test("in-memory queue") { - val page = MemoryBlock.fromLongArray(new Array[Long](1<<10)) + val page = new OnHeapMemoryBlock((1<<10) * 8L) val queue = new InMemoryRowQueue(page, 1) { override def close() {} } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala new file mode 100644 index 000000000000..fe59cb25d500 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.net.URI + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils + +abstract class CheckpointFileManagerTests extends SparkFunSuite { + + def createManager(path: Path): CheckpointFileManager + + test("mkdirs, list, createAtomic, open, delete, exists") { + withTempPath { p => + val basePath = new Path(p.getAbsolutePath) + val fm = createManager(basePath) + // Mkdirs + val dir = new Path(s"$basePath/dir/subdir/subsubdir") + assert(!fm.exists(dir)) + fm.mkdirs(dir) + assert(fm.exists(dir)) + fm.mkdirs(dir) + + // List + val acceptAllFilter = new PathFilter { + override def accept(path: Path): Boolean = true + } + val rejectAllFilter = new PathFilter { + override def accept(path: Path): Boolean = false + } + assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) + assert(fm.list(basePath, rejectAllFilter).length === 0) + + // Create atomic without overwrite + var path = new Path(s"$dir/file") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).close() + assert(fm.exists(path)) + quietly { + intercept[IOException] { + // should throw exception since file exists and overwrite is false + fm.createAtomic(path, overwriteIfPossible = false).close() + } + } + + // Create atomic with overwrite if possible + path = new Path(s"$dir/file2") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() + assert(fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() // should not throw exception + + // Open and delete + fm.open(path).close() + fm.delete(path) + assert(!fm.exists(path)) + intercept[IOException] { + fm.open(path) + } + fm.delete(path) // should not throw exception + } + } + + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } +} + +class CheckpointFileManagerSuite extends SparkFunSuite with SharedSparkSession { + + test("CheckpointFileManager.create() should pick up user-specified class from conf") { + withSQLConf( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key -> + classOf[CreateAtomicTestManager].getName) { + val fileManager = + CheckpointFileManager.create(new Path("/"), spark.sessionState.newHadoopConf) + assert(fileManager.isInstanceOf[CreateAtomicTestManager]) + } + } + + test("CheckpointFileManager.create() should fallback from FileContext to FileSystem") { + import CheckpointFileManagerSuiteFileSystem.scheme + spark.conf.set(s"fs.$scheme.impl", classOf[CheckpointFileManagerSuiteFileSystem].getName) + quietly { + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.getLatest() === Some(0 -> "batch0")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) + + val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + 
assert(metadataLog2.get(0) === Some("batch0")) + assert(metadataLog2.getLatest() === Some(0 -> "batch0")) + assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) + } + } + } +} + +class FileContextBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileContextBasedCheckpointFileManager(path, new Configuration()) + } +} + +class FileSystemBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileSystemBasedCheckpointFileManager(path, new Configuration()) + } +} + + +/** A fake implementation to test different characteristics of CheckpointFileManager interface */ +class CreateAtomicTestManager(path: Path, hadoopConf: Configuration) + extends FileSystemBasedCheckpointFileManager(path, hadoopConf) { + + import CheckpointFileManager._ + + override def createAtomic(path: Path, overwrite: Boolean): CancellableFSDataOutputStream = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + } + val originalOut = super.createAtomic(path, overwrite) + + new CancellableFSDataOutputStream(originalOut) { + override def close(): Unit = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + throw new IOException("Copy failed intentionally") + } + super.close() + } + + override def cancel(): Unit = { + CreateAtomicTestManager.cancelCalledInCreateAtomic = true + originalOut.cancel() + } + } + } +} + +object CreateAtomicTestManager { + @volatile var shouldFailInCreateAtomic = false + @volatile var cancelCalledInCreateAtomic = false +} + + +/** + * CheckpointFileManagerSuiteFileSystem to test fallback of the CheckpointFileManager + * from FileContext to FileSystem API. 
+ */ +private class CheckpointFileManagerSuiteFileSystem extends RawLocalFileSystem { + import CheckpointFileManagerSuiteFileSystem.scheme + + override def getUri: URI = { + URI.create(s"$scheme:///") + } +} + +private object CheckpointFileManagerSuiteFileSystem { + val scheme = s"CheckpointFileManagerSuiteFileSystem${math.abs(Random.nextInt)}" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 12eaf6341508..ec961a9ecb59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -22,15 +22,10 @@ import java.nio.charset.StandardCharsets._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - import CompactibleFileStreamLog._ /** -- testing of `object CompactibleFileStreamLog` begins -- */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 4677769c12a3..9268306ce427 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -17,46 +17,22 @@ package org.apache.spark.sql.execution.streaming -import java.io.{File, FileNotFoundException, IOException} -import java.net.URI +import java.io.File import java.util.ConcurrentModificationException import scala.language.implicitConversions -import scala.util.Random -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs._ import org.scalatest.concurrent.Waiters._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ -import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.{FileContextManager, FileManager, FileSystemManager} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - private implicit def toOption[A](a: A): Option[A] = Option(a) - test("FileManager: FileContextManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileContextManager(path, new Configuration)) - } - } - - test("FileManager: FileSystemManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileSystemManager(path, new Configuration)) - } - } - test("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir @@ -82,26 +58,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } 
- testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") { - spark.conf.set( - s"fs.$scheme.impl", - classOf[FakeFileSystem].getName) - withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog.add(0, "batch0")) - assert(metadataLog.getLatest() === Some(0 -> "batch0")) - assert(metadataLog.get(0) === Some("batch0")) - assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) - - - val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog2.get(0) === Some("batch0")) - assert(metadataLog2.getLatest() === Some(0 -> "batch0")) - assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) - - } - } - test("HDFSMetadataLog: purge") { withTempDir { temp => val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) @@ -121,7 +77,8 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { // There should be exactly one file, called "2", in the metadata directory. // This check also tests for regressions of SPARK-17475 - val allFiles = new File(metadataLog.metadataPath.toString).listFiles().toSeq + val allFiles = new File(metadataLog.metadataPath.toString).listFiles() + .filter(!_.getName.startsWith(".")).toSeq assert(allFiles.size == 1) assert(allFiles(0).getName() == "2") } @@ -172,7 +129,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - test("HDFSMetadataLog: metadata directory collision") { + testQuietly("HDFSMetadataLog: metadata directory collision") { withTempDir { temp => val waiter = new Waiter val maxBatchId = 100 @@ -206,60 +163,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - /** Basic test case for [[FileManager]] implementation. 
*/ - private def testFileManager(basePath: Path, fm: FileManager): Unit = { - // Mkdirs - val dir = new Path(s"$basePath/dir/subdir/subsubdir") - assert(!fm.exists(dir)) - fm.mkdirs(dir) - assert(fm.exists(dir)) - fm.mkdirs(dir) - - // List - val acceptAllFilter = new PathFilter { - override def accept(path: Path): Boolean = true - } - val rejectAllFilter = new PathFilter { - override def accept(path: Path): Boolean = false - } - assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) - assert(fm.list(basePath, rejectAllFilter).length === 0) - - // Create - val path = new Path(s"$dir/file") - assert(!fm.exists(path)) - fm.create(path).close() - assert(fm.exists(path)) - intercept[IOException] { - fm.create(path) - } - - // Open and delete - fm.open(path).close() - fm.delete(path) - assert(!fm.exists(path)) - intercept[IOException] { - fm.open(path) - } - fm.delete(path) // should not throw exception - - // Rename - val path1 = new Path(s"$dir/file1") - val path2 = new Path(s"$dir/file2") - fm.create(path1).close() - assert(fm.exists(path1)) - fm.rename(path1, path2) - intercept[FileNotFoundException] { - fm.rename(path1, path2) - } - val path3 = new Path(s"$dir/file3") - fm.create(path3).close() - assert(fm.exists(path3)) - intercept[FileAlreadyExistsException] { - fm.rename(path2, path3) - } - } - test("verifyBatchIds") { import HDFSMetadataLog.verifyBatchIds verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), Some(3L)) @@ -277,14 +180,3 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { intercept[IllegalStateException](verifyBatchIds(Seq(1, 2, 4, 5), Some(1L), Some(5L))) } } - -/** FakeFileSystem to test fallback of the HDFSMetadataLog from FileContext to FileSystem API */ -class FakeFileSystem extends RawLocalFileSystem { - override def getUri: URI = { - URI.create(s"$scheme:///") - } -} - -object FakeFileSystem { - val scheme = s"HDFSMetadataLogSuite${math.abs(Random.nextInt)}" -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index e8420eee7fe9..3bc36ce55d90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { sink.addBatch(0, 1 to 3) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 12) + assert(plan.stats.sizeInBytes === 36) sink.addBatch(1, 4 to 6) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 24) + assert(plan.stats.sizeInBytes === 72) } ignore("stress test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 00d4f0b8503d..9be22d94b565 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -40,7 +40,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new ContinuousMemoryWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala deleted file mode 100644 index e11705a227f4..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.Optional - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamSourceV2, RateStreamV2Reader} -import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport -import org.apache.spark.sql.streaming.StreamTest - -class RateSourceV2Suite extends StreamTest { - test("microbatch - numPartitions propagated") { - val reader = new RateStreamV2Reader( - new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createReadTasks() - assert(tasks.size == 11) - } - - test("microbatch - set offset") { - val reader = new RateStreamV2Reader(DataSourceV2Options.empty()) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val reader = new RateStreamV2Reader( - new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) - reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: RateStreamOffset => - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: RateStreamOffset => - // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted - // longer than 100ms. It should never be early. 
- assert(r.partitionToValueAndRunTimeMs(0).value >= 9) - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) - - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val reader = new RateStreamV2Reader( - new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createReadTasks() - assert(tasks.size == 1) - assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) - } - - test("microbatch - data read") { - val reader = new RateStreamV2Reader( - new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) - val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { - case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) - }.toMap) - - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createReadTasks() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceV2Options.empty()) - assert(reader.isInstanceOf[ContinuousRateStreamReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("continuous data") { - val reader = new ContinuousRateStreamReader( - new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setOffset(Optional.empty()) - val tasks = reader.createReadTasks() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamReadTask => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - ContinuousRateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala deleted file mode 100644 index ec1154907365..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * 
contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io.{IOException, OutputStreamWriter} -import java.net.ServerSocket -import java.sql.Timestamp -import java.util.concurrent.LinkedBlockingQueue - -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} - -class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { - import testImplicits._ - - override def afterEach() { - sqlContext.streams.active.foreach(_.stop()) - if (serverThread != null) { - serverThread.interrupt() - serverThread.join() - serverThread = null - } - if (source != null) { - source.stop() - source = null - } - } - - private var serverThread: ServerThread = null - private var source: Source = null - - test("basic usage") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) - val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 - assert(schema === StructType(StructField("value", StringType) :: Nil)) - - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - assert(batch1.as[String].collect().toSeq === Seq("hello")) - - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) - } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - assert(batch2.as[String].collect().toSeq === Seq("world")) - - val both = source.getBatch(None, offset2) - assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world")) - } - - // Try stopping the source to make sure this does not block forever. 
- source.stop() - source = null - } - } - - test("timestamped usage") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString, - "includeTimestamp" -> "true") - val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 - assert(schema === StructType(StructField("value", StringType) :: - StructField("timestamp", TimestampType) :: Nil)) - - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq - assert(batch1Seq.map(_._1) === Seq("hello")) - val batch1Stamp = batch1Seq(0)._2 - - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) - } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq - assert(batch2Seq.map(_._1) === Seq("world")) - val batch2Stamp = batch2Seq(0)._2 - assert(!batch2Stamp.before(batch1Stamp)) - } - - // Try stopping the source to make sure this does not block forever. - source.stop() - source = null - } - } - - test("params not given") { - val provider = new TextSocketSourceProvider - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map()) - } - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost")) - } - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("port" -> "1234")) - } - } - - test("non-boolean includeTimestamp") { - val provider = new TextSocketSourceProvider - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost", - "port" -> "1234", "includeTimestamp" -> "fasle")) - } - } - - test("user-specified schema given") { - val provider = new TextSocketSourceProvider - val userSpecifiedSchema = StructType( - StructField("name", StringType) :: - StructField("area", StringType) :: Nil) - val exception = intercept[AnalysisException] { - provider.sourceSchema( - sqlContext, Some(userSpecifiedSchema), - "", - Map("host" -> "localhost", "port" -> "1234")) - } - assert(exception.getMessage.contains( - "socket source does not support a user-specified schema")) - } - - test("no server up") { - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> "0") - intercept[IOException] { - source = provider.createSource(sqlContext, "", None, "", parameters) - } - } - - test("input row metrics") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val batch = source.getBatch(None, source.getOffset.get).as[String] - batch.collect() - val numRowsMetric = - batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") 
- assert(numRowsMetric.nonEmpty) - assert(numRowsMetric.get.value === 1) - } - source.stop() - source = null - } - } - - private class ServerThread extends Thread with Logging { - private val serverSocket = new ServerSocket(0) - private val messageQueue = new LinkedBlockingQueue[String]() - - val port = serverSocket.getLocalPort - - override def run(): Unit = { - try { - val clientSocket = serverSocket.accept() - clientSocket.setTcpNoDelay(true) - val out = new OutputStreamWriter(clientSocket.getOutputStream) - while (true) { - val line = messageQueue.take() - out.write(line + "\n") - out.flush() - } - } catch { - case e: InterruptedException => - } finally { - serverSocket.close() - } - } - - def enqueue(line: String): Unit = { - messageQueue.put(line) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala new file mode 100644 index 000000000000..55acf2ba28d2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.ByteArrayOutputStream + +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.{StreamTest, Trigger} + +class ConsoleWriterSuite extends StreamTest { + import testImplicits._ + + test("microbatch - default") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + input.addData(4, 5, 6) + query.processAllAvailable() + input.addData() + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + || 3| + |+-----+ + | + |------------------------------------------- + |Batch: 1 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 4| + || 5| + || 6| + |+-----+ + | + |------------------------------------------- + |Batch: 2 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + |+-----+ + | + |""".stripMargin) + } + + test("microbatch - with numRows") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + |+-----+ + |only showing top 2 rows + | + |""".stripMargin) + } + + test("microbatch - truncation") { + val input = MemoryStream[String] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start() + try { + input.addData("123456789012345678901234567890") + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+--------------------+ + || value| + |+--------------------+ + ||12345678901234567...| + |+--------------------+ + | + |""".stripMargin) + } + + test("continuous - default") { + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val input = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "5") + .load() + .select('value) + + val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() + assert(query.isActive) + query.stop() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala similarity index 67% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index 9137d650e906..e60c339bc9cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.util.concurrent.ConcurrentLinkedQueue @@ -25,11 +25,12 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext -class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { +class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ @@ -46,49 +47,34 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .foreach(new TestForeachWriter()) .start() - // -- batch 0 --------------------------------------- - input.addData(1, 2, 3, 4) - query.processAllAvailable() + def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { + import ForeachWriterSuite._ - var expectedEventsForPartition0 = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 1), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) - ) - var expectedEventsForPartition1 = Seq( - ForeachSinkSuite.Open(partition = 1, version = 0), - ForeachSinkSuite.Process(value = 2), - ForeachSinkSuite.Process(value = 4), - ForeachSinkSuite.Close(None) - ) + val events = ForeachWriterSuite.allEvents() + assert(events.size === 2) // one seq of events for each of the 2 partitions - var allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 2) - assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) + // Verify both seq of events have an Open event as the first event + assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion))) - ForeachSinkSuite.clear() + // Verify all the Process event correspond to the expected data + val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]])) + assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet) + + // Verify both seq of events have a Close event as the last event + assert(events.map(_.last).toSet === Set(Close(None), Close(None))) + } + + // -- batch 0 --------------------------------------- + ForeachWriterSuite.clear() + input.addData(1, 2, 3, 4) + query.processAllAvailable() + verifyOutput(expectedVersion = 0, expectedData = 1 to 4) // -- batch 1 --------------------------------------- + ForeachWriterSuite.clear() input.addData(5, 6, 7, 8) query.processAllAvailable() - - expectedEventsForPartition0 = Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Process(value = 5), - ForeachSinkSuite.Process(value = 7), - ForeachSinkSuite.Close(None) - ) - expectedEventsForPartition1 = Seq( - ForeachSinkSuite.Open(partition = 1, version = 1), - ForeachSinkSuite.Process(value = 6), - ForeachSinkSuite.Process(value = 8), - ForeachSinkSuite.Close(None) - ) - - allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 2) - assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) + verifyOutput(expectedVersion = 1, expectedData = 5 to 8) query.stop() } @@ -110,27 +96,27 @@ class 
ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf input.addData(1, 2, 3, 4) query.processAllAvailable() - var allEvents = ForeachSinkSuite.allEvents() + var allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) var expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 4), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 4), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() // -- batch 1 --------------------------------------- input.addData(5, 6, 7, 8) query.processAllAvailable() - allEvents = ForeachSinkSuite.allEvents() + allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Process(value = 8), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Process(value = 8), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) @@ -146,7 +132,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .foreach(new TestForeachWriter() { override def process(value: Int): Unit = { super.process(value) - throw new RuntimeException("error") + throw new RuntimeException("ForeachSinkSuite error") } }).start() input.addData(1, 2, 3, 4) @@ -156,18 +142,18 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() } assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.getMessage === "error") + assert(e.getCause.getCause.getCause.getMessage === "ForeachSinkSuite error") assert(query.isActive === false) - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) - assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(0) === ForeachWriterSuite.Open(partition = 0, version = 0)) + assert(allEvents(0)(1) === ForeachWriterSuite.Process(value = 1)) // `close` should be called with the error - val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] + val errorEvent = allEvents(0)(2).asInstanceOf[ForeachWriterSuite.Close] assert(errorEvent.error.get.isInstanceOf[RuntimeException]) - assert(errorEvent.error.get.getMessage === "error") + assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error") } } @@ -192,12 +178,12 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf inputData.addData(10, 11, 12) query.processAllAvailable() - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) val expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) } finally { @@ -225,27 +211,29 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf try { inputData.addData(10, 11, 12) query.processAllAvailable() - inputData.addData(25) // Advance watermark to 15 seconds - query.processAllAvailable() inputData.addData(25) // 
Evict items less than previous watermark query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. - val allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 3) + val allEvents = ForeachWriterSuite.allEvents() + assert(allEvents.size === 4) val expectedEvents = Seq( Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Close(None) + ), + Seq( + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 2), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 3), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) ) assert(allEvents === expectedEvents) @@ -273,7 +261,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } /** A global object to collect events in the executor */ -object ForeachSinkSuite { +object ForeachWriterSuite { trait Event @@ -300,21 +288,21 @@ object ForeachSinkSuite { /** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */ class TestForeachWriter extends ForeachWriter[Int] { - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() - private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]() + private val events = mutable.ArrayBuffer[ForeachWriterSuite.Event]() override def open(partitionId: Long, version: Long): Boolean = { - events += ForeachSinkSuite.Open(partition = partitionId, version = version) + events += ForeachWriterSuite.Open(partition = partitionId, version = version) true } override def process(value: Int): Unit = { - events += ForeachSinkSuite.Process(value) + events += ForeachWriterSuite.Process(value) } override def close(errorOrNull: Throwable): Unit = { - events += ForeachSinkSuite.Close(error = Option(errorOrNull)) - ForeachSinkSuite.addEvents(events) + events += ForeachWriterSuite.Close(error = Option(errorOrNull)) + ForeachWriterSuite.addEvents(events) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala similarity index 50% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 03d0f63fa4d7..ff14ec38e66a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -15,13 +15,24 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources +import java.nio.file.Files +import java.util.Optional import java.util.concurrent.TimeUnit -import org.apache.spark.sql.AnalysisException +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock class RateSourceSuite extends StreamTest { @@ -29,18 +40,40 @@ class RateSourceSuite extends StreamTest { import testImplicits._ case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => - source.asInstanceOf[RateStreamSource] + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - (rateSource, rateSource.getOffset.get) + val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( + rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) + (rateSource, offset) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("compatible with old path in registry") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[RateStreamProvider]) + case _ => + throw new IllegalStateException("Could not find read support for rate") } } - test("basic") { + test("microbatch - basic") { val input = spark.readStream .format("rate") .option("rowsPerSecond", "10") @@ -57,7 +90,7 @@ class RateSourceSuite extends StreamTest { ) } - test("uniform distribution of event timestamps") { + test("microbatch - uniform distribution of event timestamps") { val input = spark.readStream .format("rate") .option("rowsPerSecond", "1500") @@ -74,8 +107,74 @@ class RateSourceSuite extends StreamTest { ) } + test("microbatch - set offset") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + 
assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val tempFolder = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + tempFolder) + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + val dataReader = tasks.get(0).createDataReader() + val data = ArrayBuffer[Row]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) + } + + test("microbatch - data read") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + test("valueAtSecond") { - import RateStreamSource._ + import RateStreamProvider._ assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) @@ -161,7 +260,7 @@ class RateSourceSuite extends StreamTest { option: String, value: String, expectedMessages: Seq[String]): Unit = { - val e = intercept[StreamingQueryException] { + val e = intercept[IllegalArgumentException] { spark.readStream .format("rate") .option(option, value) @@ -171,9 +270,8 @@ class RateSourceSuite extends StreamTest { .start() .awaitTermination() } - assert(e.getCause.isInstanceOf[IllegalArgumentException]) for (msg <- expectedMessages) { - assert(e.getCause.getMessage.contains(msg)) + assert(e.getMessage.contains(msg)) } } @@ -191,4 +289,46 @@ class RateSourceSuite extends StreamTest { assert(exception.getMessage.contains( "rate source does not support a user-specified schema")) } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find read support for continuous rate") + } + } + + test("continuous data") { + val reader = new 
RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala new file mode 100644 index 000000000000..a15a980bb92f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.IOException +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.ServerSocketChannel +import java.sql.Timestamp +import java.util.Optional +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.JavaConverters._ + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + +class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { + + override def afterEach() { + sqlContext.streams.active.foreach(_.stop()) + if (serverThread != null) { + serverThread.interrupt() + serverThread.join() + serverThread = null + } + if (batchReader != null) { + batchReader.stop() + batchReader = null + } + } + + private var serverThread: ServerThread = null + private var batchReader: MicroBatchReader = null + + case class AddSocketData(data: String*) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + require( + query.nonEmpty, + "Cannot add data when there is no query for finding the active socket source") + + val sources = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source + } + if (sources.isEmpty) { + throw new Exception( + "Could not find socket source in the StreamExecution logical plan to add data to") + } else if (sources.size > 1) { + throw new Exception( + "Could not select the socket source in the StreamExecution logical plan as there" + + "are multiple socket sources:\n\t" + sources.mkString("\n\t")) + } + val socketSource = sources.head + + assert(serverThread != null && serverThread.port != 0) + val currOffset = socketSource.getCurrentOffset() + data.foreach(serverThread.enqueue) + + val newOffset = LongOffset(currOffset.offset + data.size) + (socketSource, newOffset) + } + + override def toString: String = s"AddSocketData(data = $data)" + } + + test("backward compatibility with old path") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[TextSocketSourceProvider]) + case _ => + throw new IllegalStateException("Could not find socket source") + } + } + + test("basic usage") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + } + } + + test("timestamped usage") { 
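+ // When includeTimestamp is set to true, the socket source is expected to expose a second + // column of TimestampType holding the time at which each line was received; the schema + // assertion and the batch1Stamp/batch2Stamp checks below verify that behaviour.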
+ serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val socket = spark + .readStream + .format("socket") + .options(Map( + "host" -> "localhost", + "port" -> serverThread.port.toString, + "includeTimestamp" -> "true")) + .load() + + assert(socket.schema === StructType(StructField("value", StringType) :: + StructField("timestamp", TimestampType) :: Nil)) + + var batch1Stamp: Timestamp = null + var batch2Stamp: Timestamp = null + + val curr = System.currentTimeMillis() + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswerRowsByFunc( + rows => { + assert(rows.size === 1) + assert(rows.head.getAs[String](0) === "hello") + batch1Stamp = rows.head.getAs[Timestamp](1) + Thread.sleep(10) + }, + true), + AddSocketData("world"), + CheckAnswerRowsByFunc( + rows => { + assert(rows.size === 1) + assert(rows.head.getAs[String](0) === "world") + batch2Stamp = rows.head.getAs[Timestamp](1) + }, + true), + StopStream + ) + + // Timestamp for rate stream is round to second which leads to milliseconds lost, that will + // make batch1stamp smaller than current timestamp if both of them are in the same second. + // Comparing by second to make sure the correct behavior. + assert(batch1Stamp.getTime >= curr / 1000 * 1000) + assert(!batch2Stamp.before(batch1Stamp)) + } + } + + test("params not given") { + val provider = new TextSocketSourceProvider + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map.empty[String, String].asJava)) + } + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("host" -> "localhost").asJava)) + } + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("port" -> "1234").asJava)) + } + } + + test("non-boolean includeTimestamp") { + val provider = new TextSocketSourceProvider + val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") + intercept[AnalysisException] { + val a = new DataSourceOptions(params.asJava) + provider.createMicroBatchReader(Optional.empty(), "", a) + } + } + + test("user-specified schema given") { + val provider = new TextSocketSourceProvider + val userSpecifiedSchema = StructType( + StructField("name", StringType) :: + StructField("area", StringType) :: Nil) + val params = Map("host" -> "localhost", "port" -> "1234") + val exception = intercept[AnalysisException] { + provider.createMicroBatchReader( + Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) + } + assert(exception.getMessage.contains( + "socket source does not support a user-specified schema")) + } + + test("no server up") { + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> "0") + intercept[IOException] { + batchReader = provider.createMicroBatchReader( + Optional.empty(), "", new DataSourceOptions(parameters.asJava)) + } + } + + test("input row metrics") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + 
StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AssertOnQuery { q => + val numRowMetric = + q.lastExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") + numRowMetric.nonEmpty && numRowMetric.get.value == 1 + }, + StopStream + ) + } + } + + private class ServerThread extends Thread with Logging { + private val serverSocketChannel = ServerSocketChannel.open() + serverSocketChannel.bind(new InetSocketAddress(0)) + private val messageQueue = new LinkedBlockingQueue[String]() + + val port = serverSocketChannel.socket().getLocalPort + + override def run(): Unit = { + try { + while (true) { + val clientSocketChannel = serverSocketChannel.accept() + clientSocketChannel.configureBlocking(false) + clientSocketChannel.socket().setTcpNoDelay(true) + + // Check whether remote client is closed but still send data to this closed socket. + // This happens in DataStreamReader where a source will be created to get the schema. + var remoteIsClosed = false + var cnt = 0 + while (cnt < 3 && !remoteIsClosed) { + if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) { + cnt += 1 + Thread.sleep(100) + } else { + remoteIsClosed = true + } + } + + if (remoteIsClosed) { + logInfo(s"remote client ${clientSocketChannel.socket()} is closed") + } else { + while (true) { + val line = messageQueue.take() + "\n" + clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) + } + } + } + } catch { + case e: InterruptedException => + } finally { + serverSocketChannel.close() + } + } + + def enqueue(line: String): Unit = { + messageQueue.put(line) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index c843b65020d8..73f870506040 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI import java.util.UUID -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -28,17 +27,17 @@ import scala.util.Random import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} +import org.apache.hadoop.fs._ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -138,7 +137,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(getData(provider, 19) === Set("a" -> 19)) } - test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { + testQuietly("SPARK-19677: 
Committing a delta file atop an existing one should not fail on HDFS") { val conf = new Configuration() conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) conf.set("fs.defaultFS", "fake:///") @@ -344,7 +343,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } - test("SPARK-18342: commit fails when rename fails") { + testQuietly("SPARK-18342: commit fails when rename fails") { import RenameReturnsFalseFileSystem._ val dir = scheme + "://" + newDir() val conf = new Configuration() @@ -366,7 +365,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] def numTempFiles: Int = { if (deltaFileDir.exists) { - deltaFileDir.listFiles.map(_.getName).count(n => n.contains("temp") && !n.startsWith(".")) + deltaFileDir.listFiles.map(_.getName).count(n => n.endsWith(".tmp")) } else 0 } @@ -471,6 +470,43 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } + test("error writing [version].delta cancels the output stream") { + + val hadoopConf = new Configuration() + hadoopConf.set( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key, + classOf[CreateAtomicTestManager].getName) + val remoteDir = Utils.createTempDir().getAbsolutePath + + val provider = newStoreProvider( + opId = Random.nextInt, partition = 0, dir = remoteDir, hadoopConf = hadoopConf) + + // Disable failure of output stream and generate versions + CreateAtomicTestManager.shouldFailInCreateAtomic = false + for (version <- 1 to 10) { + val store = provider.getStore(version - 1) + put(store, version.toString, version) // update "1" -> 1, "2" -> 2, ... + store.commit() + } + val version10Data = (1L to 10).map(_.toString).map(x => x -> x).toSet + + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store = provider.getStore(10) + // Fail commit for next version and verify that reloading resets the files + CreateAtomicTestManager.shouldFailInCreateAtomic = true + put(store, "11", 11) + val e = intercept[IllegalStateException] { quietly { store.commit() } } + assert(e.getCause.isInstanceOf[IOException]) + CreateAtomicTestManager.shouldFailInCreateAtomic = false + + // Abort commit for next version and verify that reloading resets the files + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store2 = provider.getStore(10) + put(store2, "11", 11) + store2.abort() + assert(CreateAtomicTestManager.cancelCalledInCreateAtomic) + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } @@ -720,6 +756,14 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] * this provider */ def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)] + + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } } object StateStoreTestsHelper { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 7d84f45d36be..02df45d1b798 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, 
SparkPlanInfo, SQLExecution} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.internal.StaticSQLConf.UI_RETAINED_EXECUTIONS import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.status.ElementTrackingStore import org.apache.spark.status.config._ @@ -442,7 +443,8 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val oldCount = statusStore.executionsList().size val expectedAccumValue = 12345 - val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue) + val expectedAccumValue2 = 54321 + val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue, expectedAccumValue2) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan override lazy val executedPlan = physicalPlan @@ -465,10 +467,14 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val execId = statusStore.executionsList().last.executionId val metrics = statusStore.executionMetrics(execId) val driverMetric = physicalPlan.metrics("dummy") + val driverMetric2 = physicalPlan.metrics("dummy2") val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue)) + val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType, Seq(expectedAccumValue2)) assert(metrics.contains(driverMetric.id)) assert(metrics(driverMetric.id) === expectedValue) + assert(metrics.contains(driverMetric2.id)) + assert(metrics(driverMetric2.id) === expectedValue2) } test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { @@ -510,6 +516,50 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with } } + test("eviction should respect execution completion time") { + val conf = sparkContext.conf.clone().set(UI_RETAINED_EXECUTIONS.key, "2") + val store = new ElementTrackingStore(new InMemoryStore, conf) + val listener = new SQLAppStatusListener(conf, store, live = true) + val statusStore = new SQLAppStatusStore(store, Some(listener)) + + var time = 0 + val df = createTestDataFrame + // Start execution 1 and execution 2 + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 1, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 2, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + + // Stop execution 2 before execution 1 + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionEnd(2, time)) + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionEnd(1, time)) + + // Start execution 3 and execution 2 should be evicted. + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 3, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + assert(statusStore.executionsCount === 2) + assert(statusStore.execution(2) === None) + } } @@ -517,20 +567,31 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with * A dummy [[org.apache.spark.sql.execution.SparkPlan]] that updates a [[SQLMetrics]] * on the driver. 
*/ -private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExecNode { +private case class MyPlan(sc: SparkContext, expectedValue: Long, expectedValue2: Long) + extends LeafExecNode { + override def sparkContext: SparkContext = sc override def output: Seq[Attribute] = Seq() override val metrics: Map[String, SQLMetric] = Map( - "dummy" -> SQLMetrics.createMetric(sc, "dummy")) + "dummy" -> SQLMetrics.createMetric(sc, "dummy"), + "dummy2" -> SQLMetrics.createMetric(sc, "dummy2")) override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue + longMetric("dummy2") += expectedValue2 + + // postDriverMetricUpdates may happen multiple time in a query. + // (normally from different operators, but for the sake of testing, from one operator) + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + Seq(metrics("dummy"))) SQLMetrics.postDriverMetricUpdates( sc, sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), - metrics.values.toSeq) + Seq(metrics("dummy2"))) sc.emptyRDD } } @@ -566,6 +627,7 @@ class SQLAppStatusListenerMemoryLeakSuite extends SparkFunSuite { sc.listenerBus.waitUntilEmpty(10000) val statusStore = spark.sharedState.statusStore assert(statusStore.executionsCount() <= 50) + assert(statusStore.planGraphCount() <= 50) // No live data should be left behind after all executions end. assert(statusStore.listener.get.noLiveData()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 03490ad15a65..b55489cb2678 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -23,6 +23,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowColumnVectorSuite extends SparkFunSuite { @@ -41,6 +42,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BooleanType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -68,6 +70,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ByteType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -95,6 +98,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ShortType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -122,6 +126,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === IntegerType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -149,6 +154,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === LongType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -176,6 +182,7 @@ 
class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === FloatType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -203,6 +210,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === DoubleType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -231,6 +239,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === StringType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -257,6 +266,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BinaryType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -299,6 +309,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ArrayType(IntegerType)) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) val array0 = columnVector.getArray(0) @@ -321,6 +332,43 @@ class ArrowColumnVectorSuite extends SparkFunSuite { allocator.close() } + test("non nullable struct") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = false, null) + .createVector(allocator).asInstanceOf[NullableMapVector] + + vector.allocateNew() + val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] + val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] + + vector.setIndexDefined(0) + intVector.setSafe(0, 1) + longVector.setSafe(0, 1L) + + vector.setIndexDefined(1) + intVector.setSafe(1, 2) + longVector.setNull(1) + + vector.setValueCount(2) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === schema) + assert(!columnVector.hasNull) + assert(columnVector.numNulls === 0) + + val row0 = columnVector.getStruct(0) + assert(row0.getInt(0) === 1) + assert(row0.getLong(1) === 1L) + + val row1 = columnVector.getStruct(1) + assert(row1.getInt(0) === 2) + assert(row1.isNullAt(1)) + + columnVector.close() + allocator.close() + } + test("struct") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) @@ -359,23 +407,24 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) - val row0 = columnVector.getStruct(0, 2) + val row0 = columnVector.getStruct(0) assert(row0.getInt(0) === 1) assert(row0.getLong(1) === 1L) - val row1 = columnVector.getStruct(1, 2) + val row1 = columnVector.getStruct(1) assert(row1.getInt(0) === 2) assert(row1.isNullAt(1)) - val row2 = columnVector.getStruct(2, 2) + val row2 = columnVector.getStruct(2) assert(row2.isNullAt(0)) assert(row2.getLong(1) === 3L) assert(columnVector.isNullAt(3)) - val row4 = columnVector.getStruct(4, 2) + val row4 = columnVector.getStruct(4) assert(row4.getInt(0) === 5) assert(row4.getLong(1) === 5L) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 54b31cee031f..2d1ad4b45678 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.columnar.ColumnAccessor import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.unsafe.types.UTF8String class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -199,17 +199,17 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) testVectors("struct", 10, structType) { testVector => - val c1 = testVector.getChildColumn(0) - val c2 = testVector.getChildColumn(1) + val c1 = testVector.getChild(0) + val c2 = testVector.getChild(1) c1.putInt(0, 123) c2.putDouble(0, 3.45) c1.putInt(1, 456) c2.putDouble(1, 5.67) - assert(testVector.getStruct(0, structType.length).get(0, IntegerType) === 123) - assert(testVector.getStruct(0, structType.length).get(1, DoubleType) === 3.45) - assert(testVector.getStruct(1, structType.length).get(0, IntegerType) === 456) - assert(testVector.getStruct(1, structType.length).get(1, DoubleType) === 5.67) + assert(testVector.getStruct(0).get(0, IntegerType) === 123) + assert(testVector.getStruct(0).get(1, DoubleType) === 3.45) + assert(testVector.getStruct(1).get(0, IntegerType) === 456) + assert(testVector.getStruct(1).get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 38ea2e47fdef..8aeb06d42895 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources.parquet +package org.apache.spark.sql.execution.vectorized import java.nio.ByteBuffer import java.nio.charset.StandardCharsets @@ -23,8 +23,6 @@ import scala.util.Random import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType} import org.apache.spark.unsafe.Platform import org.apache.spark.util.Benchmark @@ -268,17 +266,17 @@ object ColumnarBatchBenchmark { Int Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Java Array 177 / 181 1856.4 0.5 1.0X - ByteBuffer Unsafe 318 / 322 1032.0 1.0 0.6X - ByteBuffer API 1411 / 1418 232.2 4.3 0.1X - DirectByteBuffer 467 / 474 701.8 1.4 0.4X - Unsafe Buffer 178 / 185 1843.6 0.5 1.0X - Column(on heap) 178 / 184 1840.8 0.5 1.0X - Column(off heap) 341 / 344 961.8 1.0 0.5X - Column(off heap direct) 178 / 184 1845.4 0.5 1.0X - UnsafeRow (on heap) 378 / 389 866.3 1.2 0.5X - UnsafeRow (off heap) 393 / 402 834.0 1.2 0.4X - Column On Heap Append 309 / 318 1059.1 0.9 0.6X + Java Array 177 / 183 1851.1 0.5 1.0X + ByteBuffer Unsafe 314 / 330 1043.7 1.0 0.6X + ByteBuffer API 1298 / 1307 252.4 4.0 0.1X + DirectByteBuffer 465 / 483 704.2 1.4 0.4X + Unsafe Buffer 179 / 183 1835.5 0.5 1.0X + Column(on heap) 181 / 186 1815.2 0.6 1.0X + Column(off heap) 344 / 349 951.7 1.1 0.5X + Column(off heap direct) 178 / 186 1838.6 0.5 1.0X + UnsafeRow (on heap) 388 / 394 844.8 1.2 0.5X + UnsafeRow (off heap) 400 / 403 819.4 1.2 0.4X + Column On Heap Append 315 / 325 1041.8 1.0 0.6X */ val benchmark = new Benchmark("Int Read/Write", count * iters) benchmark.addCase("Java Array")(javaArray) @@ -297,7 +295,7 @@ object ColumnarBatchBenchmark { def booleanAccess(iters: Int): Unit = { val count = 8 * 1024 - val benchmark = new Benchmark("Boolean Read/Write", iters * count) + val benchmark = new Benchmark("Boolean Read/Write", iters * count.toLong) benchmark.addCase("Bitset") { i: Int => { val b = new BitSet(count) var sum = 0L @@ -337,8 +335,8 @@ object ColumnarBatchBenchmark { Boolean Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Bitset 726 / 727 462.4 2.2 1.0X - Byte Array 530 / 542 632.7 1.6 1.4X + Bitset 741 / 747 452.6 2.2 1.0X + Byte Array 531 / 542 631.6 1.6 1.4X */ benchmark.run() } @@ -394,8 +392,8 @@ object ColumnarBatchBenchmark { String Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap 332 / 338 49.3 20.3 1.0X - Off Heap 466 / 467 35.2 28.4 0.7X + On Heap 351 / 362 46.6 21.4 1.0X + Off Heap 456 / 466 35.9 27.8 0.8X */ val benchmark = new Benchmark("String Read/Write", count * iters) benchmark.addCase("On Heap")(column(MemoryMode.ON_HEAP)) @@ -434,7 +432,6 @@ object ColumnarBatchBenchmark { } def readArrays(onHeap: Boolean): Unit = { - System.gc() val vector = if (onHeap) onHeapVector else offHeapVector var sum = 0L @@ -448,7 +445,6 @@ object ColumnarBatchBenchmark { } def readArrayElements(onHeap: Boolean): Unit = { - System.gc() val vector = if (onHeap) onHeapVector else offHeapVector var sum = 0L @@ -479,10 +475,10 @@ object 
ColumnarBatchBenchmark { Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap Read Size Only 415 / 422 394.7 2.5 1.0X - Off Heap Read Size Only 394 / 402 415.9 2.4 1.1X - On Heap Read Elements 2558 / 2593 64.0 15.6 0.2X - Off Heap Read Elements 3316 / 3317 49.4 20.2 0.1X + On Heap Read Size Only 426 / 437 384.9 2.6 1.0X + Off Heap Read Size Only 406 / 421 404.0 2.5 1.0X + On Heap Read Elements 2636 / 2642 62.2 16.1 0.2X + Off Heap Read Elements 3770 / 3774 43.5 23.0 0.1X */ benchmark.run } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 7848ebdcab6d..f57f07b49826 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval @@ -65,22 +66,27 @@ class ColumnarBatchSuite extends SparkFunSuite { column => val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNotNull() reference += false + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNotNulls(3) (1 to 3).foreach(_ => reference += false) + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNull() reference += true + assert(column.hasNull) assert(column.numNulls() == 1) column.appendNulls(3) (1 to 3).foreach(_ => reference += true) + assert(column.hasNull) assert(column.numNulls() == 4) idx = column.elementsAppended @@ -88,11 +94,13 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putNotNull(idx) reference += false idx += 1 + assert(column.hasNull) assert(column.numNulls() == 4) column.putNull(idx) reference += true idx += 1 + assert(column.hasNull) assert(column.numNulls() == 5) column.putNulls(idx, 3) @@ -100,6 +108,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += true reference += true idx += 3 + assert(column.hasNull) assert(column.numNulls() == 8) column.putNotNulls(idx, 4) @@ -108,6 +117,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += false reference += false idx += 4 + assert(column.hasNull) assert(column.numNulls() == 8) reference.zipWithIndex.foreach { v => @@ -562,7 +572,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } } - testVector("String APIs", 6, StringType) { + testVector("String APIs", 7, StringType) { column => val reference = mutable.ArrayBuffer.empty[String] @@ -609,6 +619,10 @@ class ColumnarBatchSuite extends SparkFunSuite { idx += 1 assert(column.arrayData().elementsAppended == 17 + (s + s).length) + column.putNull(idx) + assert(column.getUTF8String(idx) == null) + idx += 1 + reference.zipWithIndex.foreach { v => val errMsg = "VectorType=" + column.getClass.getSimpleName assert(v._1.length == column.getArrayLength(v._2), errMsg) @@ -619,6 +633,40 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.arrayData().elementsAppended == 0) } + 
testVector("CalendarInterval APIs", 4, CalendarIntervalType) { + column => + val reference = mutable.ArrayBuffer.empty[CalendarInterval] + + val months = column.getChild(0) + val microseconds = column.getChild(1) + assert(months.dataType() == IntegerType) + assert(microseconds.dataType() == LongType) + + months.putInt(0, 1) + microseconds.putLong(0, 100) + reference += new CalendarInterval(1, 100) + + months.putInt(1, 0) + microseconds.putLong(1, 2000) + reference += new CalendarInterval(0, 2000) + + column.putNull(2) + assert(column.getInterval(2) == null) + reference += null + + months.putInt(3, 20) + microseconds.putLong(3, 0) + reference += new CalendarInterval(20, 0) + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + assert(v == column.getInterval(i), errMsg) + if (v == null) assert(column.isNullAt(i), errMsg) + } + + column.close() + } + testVector("Int Array", 10, new ArrayType(IntegerType, true)) { column => @@ -630,35 +678,38 @@ class ColumnarBatchSuite extends SparkFunSuite { i += 1 } - // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + // Populate it with arrays [0], [1, 2], null, [], [3, 4, 5] column.putArray(0, 0, 1) column.putArray(1, 1, 2) - column.putArray(2, 2, 0) - column.putArray(3, 3, 3) + column.putNull(2) + column.putArray(3, 3, 0) + column.putArray(4, 3, 3) + + assert(column.getArray(0).numElements == 1) + assert(column.getArray(1).numElements == 2) + assert(column.isNullAt(2)) + assert(column.getArray(2) == null) + assert(column.getArray(3).numElements == 0) + assert(column.getArray(4).numElements == 3) val a1 = ColumnVectorUtils.toJavaIntArray(column.getArray(0)) val a2 = ColumnVectorUtils.toJavaIntArray(column.getArray(1)) - val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(2)) - val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) + val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) + val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(4)) assert(a1 === Array(0)) assert(a2 === Array(1, 2)) assert(a3 === Array.empty[Int]) assert(a4 === Array(3, 4, 5)) - // Verify the ArrayData APIs - assert(column.getArray(0).numElements() == 1) + // Verify the ArrayData get APIs assert(column.getArray(0).getInt(0) == 0) - assert(column.getArray(1).numElements() == 2) assert(column.getArray(1).getInt(0) == 1) assert(column.getArray(1).getInt(1) == 2) - assert(column.getArray(2).numElements() == 0) - - assert(column.getArray(3).numElements() == 3) - assert(column.getArray(3).getInt(0) == 3) - assert(column.getArray(3).getInt(1) == 4) - assert(column.getArray(3).getInt(2) == 5) + assert(column.getArray(4).getInt(0) == 3) + assert(column.getArray(4).getInt(1) == 4) + assert(column.getArray(4).getInt(2) == 5) // Add a longer array which requires resizing column.reset() @@ -668,8 +719,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) column.putArray(0, 0, array.length) - assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) - === array) + assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) === array) } test("toArray for primitive types") { @@ -727,25 +777,70 @@ class ColumnarBatchSuite extends SparkFunSuite { } } + test("Int Map") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val column = allocate(10, new MapType(IntegerType, IntegerType, false), memMode) + (0 to 1).foreach { colIndex => + val data = column.getChild(colIndex) + (0 to 5).foreach {i => + 
data.putInt(i, i * (colIndex + 1)) + } + } + + // Populate it with maps [0->0], [1->2, 2->4], null, [], [3->6, 4->8, 5->10] + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putNull(2) + assert(column.getMap(2) == null) + column.putArray(3, 3, 0) + column.putArray(4, 3, 3) + + assert(column.getMap(0).numElements == 1) + assert(column.getMap(1).numElements == 2) + assert(column.isNullAt(2)) + assert(column.getMap(3).numElements == 0) + assert(column.getMap(4).numElements == 3) + + val a1 = ColumnVectorUtils.toJavaIntMap(column.getMap(0)) + val a2 = ColumnVectorUtils.toJavaIntMap(column.getMap(1)) + val a4 = ColumnVectorUtils.toJavaIntMap(column.getMap(3)) + val a5 = ColumnVectorUtils.toJavaIntMap(column.getMap(4)) + + assert(a1.asScala == Map(0 -> 0)) + assert(a2.asScala == Map(1 -> 2, 2 -> 4)) + assert(a4.asScala == Map()) + assert(a5.asScala == Map(3 -> 6, 4 -> 8, 5 -> 10)) + + column.close() + } + } + testVector( "Struct Column", 10, new StructType().add("int", IntegerType).add("double", DoubleType)) { column => - val c1 = column.getChildColumn(0) - val c2 = column.getChildColumn(1) + val c1 = column.getChild(0) + val c2 = column.getChild(1) assert(c1.dataType() == IntegerType) assert(c2.dataType() == DoubleType) c1.putInt(0, 123) c2.putDouble(0, 3.45) - c1.putInt(1, 456) - c2.putDouble(1, 5.67) + + column.putNull(1) + assert(column.getStruct(1) == null) + + c1.putInt(2, 456) + c2.putDouble(2, 5.67) val s = column.getStruct(0) assert(s.getInt(0) == 123) assert(s.getDouble(1) == 3.45) - val s2 = column.getStruct(1) + assert(column.isNullAt(1)) + assert(column.getStruct(1) == null) + + val s2 = column.getStruct(2) assert(s2.getInt(0) == 456) assert(s2.getDouble(1) == 5.67) } @@ -786,8 +881,8 @@ class ColumnarBatchSuite extends SparkFunSuite { 10, new ArrayType(structType, true)) { column => val data = column.arrayData() - val c0 = data.getChildColumn(0) - val c1 = data.getChildColumn(1) + val c0 = data.getChild(0) + val c1 = data.getChild(1) // Structs in child column: (0, 0), (1, 10), (2, 20), (3, 30), (4, 40), (5, 50) (0 until 6).foreach { i => c0.putInt(i, i) @@ -814,8 +909,8 @@ class ColumnarBatchSuite extends SparkFunSuite { new StructType() .add("int", IntegerType) .add("array", new ArrayType(IntegerType, true))) { column => - val c0 = column.getChildColumn(0) - val c1 = column.getChildColumn(1) + val c0 = column.getChild(0) + val c1 = column.getChild(1) c0.putInt(0, 0) c0.putInt(1, 1) c0.putInt(2, 2) @@ -843,13 +938,13 @@ class ColumnarBatchSuite extends SparkFunSuite { "Nest Struct in Struct", 10, new StructType().add("int", IntegerType).add("struct", subSchema)) { column => - val c0 = column.getChildColumn(0) - val c1 = column.getChildColumn(1) + val c0 = column.getChild(0) + val c1 = column.getChild(1) c0.putInt(0, 0) c0.putInt(1, 1) c0.putInt(2, 2) - val c1c0 = c1.getChildColumn(0) - val c1c1 = c1.getChildColumn(1) + val c1c0 = c1.getChild(0) + val c1c1 = c1.getChild(1) // Structs in c1: (7, 70), (8, 80), (9, 90) c1c0.putInt(0, 7) c1c0.putInt(1, 8) @@ -874,14 +969,13 @@ class ColumnarBatchSuite extends SparkFunSuite { .add("intCol2", IntegerType) .add("string", BinaryType) - val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE + val capacity = 4 * 1024 val columns = schema.fields.map { field => allocate(capacity, field.dataType, memMode) } - val batch = new ColumnarBatch(schema, columns.toArray, ColumnarBatch.DEFAULT_BATCH_SIZE) + val batch = new ColumnarBatch(columns.toArray) assert(batch.numCols() == 4) assert(batch.numRows() == 0) - assert(batch.capacity() > 0) 
assert(batch.rowIterator().hasNext == false) // Add a row [1, 1.1, NULL] @@ -918,10 +1012,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(it.hasNext == false) // Reset and add 3 rows - batch.reset() - assert(batch.numRows() == 0) - assert(batch.rowIterator().hasNext == false) - + columns.foreach(_.reset()) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] columns(0).putNull(0) columns(1).putDouble(0, 2.2) @@ -1155,7 +1246,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2)) val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType))) - val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11) + val batch = new ColumnarBatch(columnVectors.toArray) batch.setNumRows(11) assert(batch.numCols() == 2) @@ -1178,4 +1269,75 @@ class ColumnarBatchSuite extends SparkFunSuite { batch.close() allocator.close() } + + testVector("Decimal API", 4, DecimalType.IntDecimal) { + column => + + val reference = mutable.ArrayBuffer.empty[Decimal] + + var idx = 0 + column.putDecimal(idx, new Decimal().set(10), 10) + reference += new Decimal().set(10) + idx += 1 + + column.putDecimal(idx, new Decimal().set(20), 10) + reference += new Decimal().set(20) + idx += 1 + + column.putNull(idx) + assert(column.getDecimal(idx, 10, 0) == null) + reference += null + idx += 1 + + column.putDecimal(idx, new Decimal().set(30), 10) + reference += new Decimal().set(30) + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + assert(v == column.getDecimal(i, 10, 0), errMsg) + if (v == null) assert(column.isNullAt(i), errMsg) + } + + column.close() + } + + testVector("Binary APIs", 4, BinaryType) { + column => + + val reference = mutable.ArrayBuffer.empty[String] + var idx = 0 + column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8)) + reference += "Hello" + idx += 1 + + column.putByteArray(idx, "World".getBytes(StandardCharsets.UTF_8)) + reference += "World" + idx += 1 + + column.putNull(idx) + reference += null + idx += 1 + + column.putByteArray(idx, "abc".getBytes(StandardCharsets.UTF_8)) + reference += "abc" + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + if (v != null) { + assert(v == new String(column.getBinary(i)), errMsg) + } else { + assert(column.isNullAt(i), errMsg) + assert(column.getBinary(i) == null, errMsg) + } + } + + column.close() + } + + testVector("WritableColumnVector.reserve(): requested capacity is negative", 1024, ByteType) { + column => + val ex = intercept[RuntimeException] { column.reserve(-1) } + assert(ex.getMessage.contains( + "Cannot reserve additional contiguous bytes in the vectorized reader (integer overflow)")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index cb2df0ac54f4..5238adce4a69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1168,4 +1168,26 @@ class JDBCSuite extends SparkFunSuite val df3 = sql("SELECT * FROM test_sessionInitStatement") assert(df3.collect() === Array(Row(21519, 1234))) } + + test("jdbc data source shouldn't have unnecessary metadata in its schema") { + val schema = StructType(Seq( + StructField("NAME", StringType, true), StructField("THEID", 
IntegerType, true))) + + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("DbTaBle", "TEST.PEOPLE") + .load() + assert(df.schema === schema) + + withTempView("people_view") { + sql( + s""" + |CREATE TEMPORARY VIEW people_view + |USING org.apache.spark.sql.jdbc + |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + assert(sql("select * from people_view").schema === schema) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index 735e07c21373..e1022e377132 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -33,6 +33,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { private var targetPartitionSchema: StructType = _ override def beforeAll(): Unit = { + super.beforeAll() targetAttributes = Seq('a.int, 'd.int, 'b.int, 'c.int) targetPartitionSchema = new StructType() .add("b", IntegerType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 8b7e2e5f4594..fef01c860db6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -21,6 +21,8 @@ import java.io.File import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -442,4 +444,80 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { assert(e.contains("Only Data Sources providing FileFormat are supported")) } } + + test("SPARK-20236: dynamic partition overwrite without catalog table") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTempPath { path => + Seq((1, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1, 1)) + + Seq((2, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1)) + + Seq((2, 2, 2)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + sql("insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + 
sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite with customer partition path") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + val path1 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=1) location '$path1'") + sql(s"insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + val path2 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=2) location '$path2'") + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 0fe33e87318a..27c983f270bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -22,6 +22,7 @@ import java.sql.Timestamp import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.TestUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils @@ -86,15 +87,15 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempDir { f => spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) } } @@ -106,7 +107,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { .option("maxRecordsPerFile", 1) .mode("overwrite") .parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 
4) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) } } @@ -133,14 +134,14 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val df = Seq((1, ts)).toDF("i", "ts") withTempPath { f => df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) checkPartitionValues(files.head, "2016-12-01 00:00:00") } withTempPath { f => df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // use timeZone option "GMT" to format partition value. checkPartitionValues(files.head, "2016-12-01 08:00:00") @@ -148,7 +149,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // if there isn't timeZone option, then use session local timezone. checkPartitionValues(files.head, "2016-12-01 08:00:00") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 773d34dfaf9a..12779b46bfe8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -126,4 +126,20 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA checkLoad(df2, "jsonTable2") } + + test("SPARK-23459: Improve error message when specified unknown column in partition columns") { + withTempDir { dir => + val path = dir.getCanonicalPath + val unknown = "unknownColumn" + val df = Seq(1L -> "a").toDF("i", "j") + val schemaCatalog = df.schema.catalogString + val e = intercept[AnalysisException] { + df.write + .format("parquet") + .partitionBy(unknown) + .save(path) + }.getMessage + assert(e.contains(s"Partition column `$unknown` not found in schema $schemaCatalog")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala similarity index 62% rename from sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala index 90d92864b26f..cfa69a86de1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala @@ -22,24 +22,24 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite /** - * A simple test suite to verify `DataSourceV2Options`. + * A simple test suite to verify `DataSourceOptions`. 
*/ -class DataSourceV2OptionsSuite extends SparkFunSuite { +class DataSourceOptionsSuite extends SparkFunSuite { test("key is case-insensitive") { - val options = new DataSourceV2Options(Map("foo" -> "bar").asJava) + val options = new DataSourceOptions(Map("foo" -> "bar").asJava) assert(options.get("foo").get() == "bar") assert(options.get("FoO").get() == "bar") assert(!options.get("abc").isPresent) } test("value is case-sensitive") { - val options = new DataSourceV2Options(Map("foo" -> "bAr").asJava) + val options = new DataSourceOptions(Map("foo" -> "bAr").asJava) assert(options.get("foo").get == "bAr") } test("getInt") { - val options = new DataSourceV2Options(Map("numFOo" -> "1", "foo" -> "bar").asJava) + val options = new DataSourceOptions(Map("numFOo" -> "1", "foo" -> "bar").asJava) assert(options.getInt("numFOO", 10) == 1) assert(options.getInt("numFOO2", 10) == 10) @@ -49,7 +49,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getBoolean") { - val options = new DataSourceV2Options( + val options = new DataSourceOptions( Map("isFoo" -> "true", "isFOO2" -> "false", "foo" -> "bar").asJava) assert(options.getBoolean("isFoo", false)) assert(!options.getBoolean("isFoo2", true)) @@ -59,7 +59,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getLong") { - val options = new DataSourceV2Options(Map("numFoo" -> "9223372036854775807", + val options = new DataSourceOptions(Map("numFoo" -> "9223372036854775807", "foo" -> "bar").asJava) assert(options.getLong("numFOO", 0L) == 9223372036854775807L) assert(options.getLong("numFoo2", -1L) == -1L) @@ -70,7 +70,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getDouble") { - val options = new DataSourceV2Options(Map("numFoo" -> "922337.1", + val options = new DataSourceOptions(Map("numFoo" -> "922337.1", "foo" -> "bar").asJava) assert(options.getDouble("numFOO", 0d) == 922337.1d) assert(options.getDouble("numFoo2", -1.02d) == -1.02d) @@ -79,4 +79,29 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { options.getDouble("foo", 0.1d) } } + + test("standard options") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATH_KEY -> "abc", + DataSourceOptions.TABLE_KEY -> "tbl").asJava) + + assert(options.paths().toSeq == Seq("abc")) + assert(options.tableName().get() == "tbl") + assert(!options.databaseName().isPresent) + } + + test("standard options with both singular path and multi-paths") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATH_KEY -> "abc", + DataSourceOptions.PATHS_KEY -> """["c", "d"]""").asJava) + + assert(options.paths().toSeq == Seq("abc", "c", "d")) + } + + test("standard options with only multi-paths") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATHS_KEY -> """["c", "d\"e"]""").asJava) + + assert(options.paths().toSeq == Seq("c", "d\"e")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index ab37e4984bd1..e0a53272cd22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -22,12 +22,18 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import 
org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -44,19 +50,77 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("advanced implementation") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + + def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader] + }.head + } + Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) - checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i))) - checkAnswer(df.select('i).filter('i > 10), Nil) + + val q1 = df.select('j) + checkAnswer(q1, (0 until 10).map(i => Row(-i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } else { + val reader = getJavaReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } + + val q2 = df.filter('i > 3) + checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + } else { + val reader = getJavaReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + } + + val q3 = df.select('i).filter('i > 6) + checkAnswer(q3, (7 until 10).map(i => Row(i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) + } else { + val reader = getJavaReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) + } + + val q4 = df.select('j).filter('j < -10) + checkAnswer(q4, Nil) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q4) + // 'j < 10 is not supported by the testing data source. 
+ assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } else { + val reader = getJavaReader(q4) + // 'j < 10 is not supported by the testing data source. + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } } } } - test("unsafe row implementation") { + test("unsafe row scan implementation") { Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -67,11 +131,22 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("columnar batch scan implementation") { + Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 90).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 90).map(i => Row(-i))) + checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i))) + } + } + } + test("schema required data source") { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) - assert(e.message.contains("A schema needs to be specified")) + assert(e.message.contains("requires a user-supplied schema")) val schema = new StructType().add("i", "int").add("s", "string") val df = spark.read.format(cls.getName).schema(schema).load() @@ -82,6 +157,45 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("partitioning reporting") { + import org.apache.spark.sql.functions.{count, sum} + Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) + + val groupByColA = df.groupBy('a).agg(sum('b)) + checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) + assert(groupByColA.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColAB = df.groupBy('a, 'b).agg(count("*")) + checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) + assert(groupByColAB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColB = df.groupBy('b).agg(sum('a)) + checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) + assert(groupByColB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isDefined) + + val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) + checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) + assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isDefined) + } + } + } + + test("SPARK-23574: no shuffle exchange with single partition") { + val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*")) + assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty) + } + test("simple writable data source") { // TODO: java implementation. 
Seq(classOf[SimpleWritableDataSource]).foreach { cls => @@ -149,25 +263,116 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("simple counter in writer with onDataWriterCommit") { + Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + val numPartition = 6 + spark.range(0, 10, 1, numPartition).select('id, -'id).write.format(cls.getName) + .option("path", path).save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).select('id, -'id)) + + assert(SimpleCounter.getCounter == numPartition, + "method onDataWriterCommit should be called as many as the number of partitions") + } + } + } + + test("SPARK-23293: data source v2 self join") { + val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load() + val df2 = df.select(($"i" + 1).as("k"), $"j") + checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1))) + } + + test("SPARK-23301: column pruning with arbitrary expressions") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + + val q1 = df.select('i + 1) + checkAnswer(q1, (1 until 11).map(i => Row(i))) + val reader1 = getReader(q1) + assert(reader1.requiredSchema.fieldNames === Seq("i")) + + val q2 = df.select(lit(1)) + checkAnswer(q2, (0 until 10).map(i => Row(1))) + val reader2 = getReader(q2) + assert(reader2.requiredSchema.isEmpty) + + // 'j === 1 can't be pushed down, but we should still be able do column pruning + val q3 = df.filter('j === -1).select('j * 2) + checkAnswer(q3, Row(-2)) + val reader3 = getReader(q3) + assert(reader3.filters.isEmpty) + assert(reader3.requiredSchema.fieldNames === Seq("j")) + + // column pruning should work with other operators. 
+ val q4 = df.sort('i).limit(1).select('i + 1) + checkAnswer(q4, Row(1)) + val reader4 = getReader(q4) + assert(reader4.requiredSchema.fieldNames === Seq("i")) + } + + test("SPARK-23315: get output from canonicalized data source v2 related plans") { + def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + val logical = df.queryExecution.optimizedPlan.collect { + case d: DataSourceV2Relation => d + }.head + assert(logical.canonicalized.output.length == numOutput) + + val physical = df.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d + }.head + assert(physical.canonicalized.output.length == numOutput) + } + + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + checkCanonicalizedOutput(df, 2) + checkCanonicalizedOutput(df.select('i), 1) + } +} + +class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceReader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5)) + } + } + + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader { + class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createReadTasks(): JList[ReadTask[Row]] = { - java.util.Arrays.asList(new SimpleReadTask(0, 5), new SimpleReadTask(5, 10)) + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10)) } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader[Row] { +class SimpleDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[Row] + with DataReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = new SimpleReadTask(start, end) + override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end) override def next(): Boolean = { current += 1 @@ -183,7 +388,7 @@ class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader + class Reader extends DataSourceReader with SupportsPushDownRequiredColumns with SupportsPushDownFilters { var requiredSchema = new StructType().add("i", "int").add("j", "int") @@ -194,8 +399,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { } override def pushFilters(filters: Array[Filter]): Array[Filter] = { - this.filters = filters - Array.empty + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false + } + this.filters = supported + unsupported } override def pushedFilters(): Array[Filter] = filters @@ -204,37 +413,37 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def createReadTasks(): JList[ReadTask[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { val lowerBound = filters.collect { case 
GreaterThan("i", v: Int) => v }.headOption - val res = new ArrayList[ReadTask[Row]] + val res = new ArrayList[DataReaderFactory[Row]] if (lowerBound.isEmpty) { - res.add(new AdvancedReadTask(0, 5, requiredSchema)) - res.add(new AdvancedReadTask(5, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema)) + res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) } else if (lowerBound.get < 4) { - res.add(new AdvancedReadTask(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedReadTask(5, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) } else if (lowerBound.get < 9) { - res.add(new AdvancedReadTask(lowerBound.get + 1, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 10, requiredSchema)) } res } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class AdvancedReadTask(start: Int, end: Int, requiredSchema: StructType) - extends ReadTask[Row] with DataReader[Row] { +class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) + extends DataReaderFactory[Row] with DataReader[Row] { private var current = start - 1 override def createDataReader(): DataReader[Row] = { - new AdvancedReadTask(start, end, requiredSchema) + new AdvancedDataReaderFactory(start, end, requiredSchema) } override def close(): Unit = {} @@ -256,26 +465,27 @@ class AdvancedReadTask(start: Int, end: Int, requiredSchema: StructType) class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader with SupportsScanUnsafeRow { + class Reader extends DataSourceReader with SupportsScanUnsafeRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createUnsafeRowReadTasks(): JList[ReadTask[UnsafeRow]] = { - java.util.Arrays.asList(new UnsafeRowReadTask(0, 5), new UnsafeRowReadTask(5, 10)) + override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { + java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5), + new UnsafeRowDataReaderFactory(5, 10)) } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class UnsafeRowReadTask(start: Int, end: Int) - extends ReadTask[UnsafeRow] with DataReader[UnsafeRow] { +class UnsafeRowDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] { private val row = new UnsafeRow(2) row.pointTo(new Array[Byte](8 * 3), 8 * 3) private var current = start - 1 - override def createDataReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) + override def createDataReader(): DataReader[UnsafeRow] = this override def next(): Boolean = { current += 1 @@ -292,11 +502,109 @@ class UnsafeRowReadTask(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { - class Reader(val readSchema: StructType) extends DataSourceV2Reader { - override def createReadTasks(): JList[ReadTask[Row]] = + class Reader(val readSchema: StructType) extends DataSourceReader { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = java.util.Collections.emptyList() } - override def createReader(schema: StructType, options: 
DataSourceV2Options): DataSourceV2Reader = + override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = new Reader(schema) } + +class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceReader with SupportsScanColumnarBatch { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { + java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90)) + } + } + + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader +} + +class BatchDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] { + + private final val BATCH_SIZE = 20 + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch(Array(i, j)) + + private var current = start + + override def createDataReader(): DataReader[ColumnarBatch] = this + + override def next(): Boolean = { + i.reset() + j.reset() + + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } + + if (count == 0) { + false + } else { + batch.setNumRows(count) + true + } + } + + override def get(): ColumnarBatch = { + batch + } + + override def close(): Unit = batch.close() +} + +class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceReader with SupportsReportPartitioning { + override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") + + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + // Note that we don't have same value of column `a` across partitions. 
+ java.util.Arrays.asList( + new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificDataReaderFactory(Array(2, 4, 4), Array(6, 2, 2))) + } + + override def outputPartitioning(): Partitioning = new MyPartitioning + } + + class MyPartitioning extends Partitioning { + override def numPartitions(): Int = 2 + + override def satisfy(distribution: Distribution): Boolean = distribution match { + case c: ClusteredDistribution => c.clusteredColumns.contains("a") + case _ => false + } + } + + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader +} + +class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) + extends DataReaderFactory[Row] + with DataReader[Row] { + assert(i.length == j.length) + + private var current = -1 + + override def createDataReader(): DataReader[Row] = this + + override def next(): Boolean = { + current += 1 + current < i.length + } + + override def get(): Row = Row(i(current), j(current)) + + override def close(): Unit = {} +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index cd7252eb2e3d..a5007fa32135 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataSourceV2Reader, ReadTask} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -42,10 +42,10 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { + class Reader(path: String, conf: Configuration) extends DataSourceReader { override def readSchema(): StructType = schema - override def createReadTasks(): JList[ReadTask[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -54,7 +54,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS name.startsWith("_") || name.startsWith(".") }.map { f => val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVReadTask(f.getPath.toUri.toString, serializableConf): ReadTask[Row] + new SimpleCSVDataReaderFactory( + f.getPath.toUri.toString, + serializableConf): DataReaderFactory[Row] }.toList.asJava } else { Collections.emptyList() @@ -62,11 +64,16 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } - class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceV2Writer { + class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { override def createWriterFactory(): DataWriterFactory[Row] = { + SimpleCounter.resetCounter new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } + override def 
onDataWriterCommit(message: WriterCommitMessage): Unit = { + SimpleCounter.increaseCounter + } + override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) @@ -102,7 +109,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = { + override def createReader(options: DataSourceOptions): DataSourceReader = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration new Reader(path.toUri.toString, conf) @@ -112,7 +119,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS jobId: String, schema: StructType, mode: SaveMode, - options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + options: DataSourceOptions): Optional[DataSourceWriter] = { assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) @@ -139,7 +146,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } private def createWriter( - jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceV2Writer = { + jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceWriter = { val pathStr = path.toUri.toString if (internal) { new InternalRowWriter(jobId, pathStr, conf) @@ -149,8 +156,8 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } -class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) - extends ReadTask[Row] with DataReader[Row] { +class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) + extends DataReaderFactory[Row] with DataReader[Row] { @transient private var lines: Iterator[String] = _ @transient private var currentLine: String = _ @@ -181,10 +188,29 @@ class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) } } +private[v2] object SimpleCounter { + private var count: Int = 0 + + def increaseCounter: Unit = { + count += 1 + } + + def getCounter: Int = { + count + } + + def resetCounter: Unit = { + count = 0 + } +} + class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) @@ -217,7 +243,10 @@ class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index d6bef9ce0737..7e8fde1ff8e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -137,20 +138,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche assert(e.get("watermark") === formatTimestamp(5)) }, AddData(inputData2, 25), - CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(25)) - assert(e.get("min") === formatTimestamp(25)) - assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(5)) - }, - AddData(inputData2, 25), CheckAnswer((10, 3)), assertEventStats { e => assert(e.get("max") === formatTimestamp(25)) assert(e.get("min") === formatTimestamp(25)) assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(15)) + assert(e.get("watermark") === formatTimestamp(5)) } ) } @@ -167,15 +160,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckNewAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - assertNumStateRows(3), - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10, 5)), + CheckNewAnswer((10, 5)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -193,15 +183,15 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation, OutputMode.Update)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch((10, 5), (15, 1)), + CheckNewAnswer((10, 5), (15, 1)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch((25, 1)), - assertNumStateRows(3), + CheckNewAnswer((25, 1)), + assertNumStateRows(2), AddData(inputData, 10, 25), // Ignore 10 as its less than watermark - CheckLastBatch((25, 2)), + CheckNewAnswer((25, 2)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -251,56 +241,25 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(df)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - StopStream, - StartStream(), - CheckLastBatch(), - AddData(inputData, 25), // Evict items less than previous watermark. 
- CheckLastBatch((10, 5)), + CheckAnswer((10, 5)), StopStream, AssertOnQuery { q => // purge commit and clear the sink - val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + 1L + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) q.commitLog.purge(commit) q.sink.asInstanceOf[MemorySink].clear() true }, StartStream(), - CheckLastBatch((10, 5)), // Recompute last batch and re-evict timestamp 10 - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), + AddData(inputData, 10, 27, 30), // Advance watermark to 20 seconds, 10 should be ignored + CheckAnswer((15, 1)), StopStream, - StartStream(), // Watermark should still be 15 seconds - AddData(inputData, 17), - CheckLastBatch(), // We still do not see next batch - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), - AddData(inputData, 30), // Evict items less than previous watermark. - CheckLastBatch((15, 2)) // Ensure we see next window - ) - } - - test("dropping old data") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - - testStream(windowedAggregation)( - AddData(inputData, 10, 11, 12), - CheckAnswer(), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckAnswer((10, 3)), - AddData(inputData, 10), // 10 is later than 15 second watermark - CheckAnswer((10, 3)), - AddData(inputData, 25), - CheckAnswer((10, 3)) // Should not emit an incorrect partial result. + StartStream(), + AddData(inputData, 17), // Watermark should still be 20 seconds, 17 should be ignored + CheckAnswer((15, 1)), + AddData(inputData, 40), // Advance watermark to 30 seconds, emit first data 25 + CheckNewAnswer((25, 2)) ) } @@ -421,8 +380,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche AddData(inputData, 10), CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. 
CheckAnswer((10, 1)) ) } @@ -501,8 +458,35 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche } } + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckNewAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + // Check if there is new answer if flag is set, no new answer otherwise + if (flag) CheckNewAnswer((10, 5)) else CheckNewAnswer() + ) + } + + testWithFlag(true) + testWithFlag(false) + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + q.processAllAvailable() + val progressWithData = q.recentProgress.lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 8c4e1fd00b0a..ed53def556cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.hadoop.fs.Path +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ @@ -33,6 +34,19 @@ import org.apache.spark.util.Utils class FileStreamSinkSuite extends StreamTest { import testImplicits._ + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + test("unpartitioned writing and batch reading") { val inputData = MemoryStream[Int] val df = inputData.toDF() @@ -265,13 +279,10 @@ class FileStreamSinkSuite extends StreamTest { check() // nothing emitted yet addTimestamp(104, 123) // watermark = 90 before this, watermark = 123 - 10 = 113 after this - check() // nothing emitted yet + check((100L, 105L) -> 2L) // no-data-batch emits results on 100-105, addTimestamp(140) // wm = 113 before this, emit results on 100-105, wm = 130 after this - check((100L, 105L) -> 2L) - - addTimestamp(150) // wm = 130s before this, emit results on 120-125, wm = 150 after this - check((100L, 105L) -> 2L, (120L, 125L) -> 1L) + check((100L, 105L) -> 2L, (120L, 125L) -> 1L) // no-data-batch emits results on 120-125 } finally { if (query != null) { @@ -340,7 +351,7 @@ class FileStreamSinkSuite extends StreamTest { } test("FileStreamSink.ancestorIsMetadataDirectory()") { - val hadoopConf = spark.sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() def 
assertAncestorIsMetadataDirectory(path: String): Unit = assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) def assertAncestorIsNotMetadataDirectory(path: String): Unit = @@ -392,4 +403,55 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("SPARK-23288 writing and checking output metrics") { + Seq("parquet", "orc", "text", "json").foreach { format => + val inputData = MemoryStream[String] + val df = inputData.toDF() + + withTempDir { outputDir => + withTempDir { checkpointDir => + + var query: StreamingQuery = null + + var numTasks = 0 + var recordsWritten: Long = 0L + var bytesWritten: Long = 0L + try { + spark.sparkContext.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val outputMetrics = taskEnd.taskMetrics.outputMetrics + recordsWritten += outputMetrics.recordsWritten + bytesWritten += outputMetrics.bytesWritten + numTasks += 1 + } + }) + + query = + df.writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format(format) + .start(outputDir.getCanonicalPath) + + inputData.addData("1", "2", "3") + inputData.addData("4", "5") + + failAfter(streamingTimeout) { + query.processAllAvailable() + } + spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis) + + assert(numTasks > 0) + assert(recordsWritten === 5) + // This is heavily file type/version specific but should be filled + assert(bytesWritten > 0) + } finally { + if (query != null) { + query.stop() + } + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 39bb57274061..d4bd9c7987f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -74,11 +74,11 @@ abstract class FileStreamSourceTest protected def addData(source: FileStreamSource): Unit } - case class AddTextFileData(content: String, src: File, tmp: File) + case class AddTextFileData(content: String, src: File, tmp: File, tmpFilePrefix: String = "text") extends AddFileData { override def addData(source: FileStreamSource): Unit = { - val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val tempFile = Utils.tempFileWith(new File(tmp, tmpFilePrefix)) val finalFile = new File(src, tempFile.getName) src.mkdirs() require(stringToFile(tempFile, content).renameTo(finalFile)) @@ -207,6 +207,19 @@ class FileStreamSourceSuite extends FileStreamSourceTest { .collect { case s @ StreamingRelation(dataSource, _, _) => s.schema }.head } + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + // ============= Basic parameter exists tests ================ test("FileStreamSource schema: no path") { @@ -408,6 +421,52 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("SPARK-21996 read from text files -- file name has space") { + withTempDirs { case (src, tmp) => + val textStream = createFileStream("text", src.getCanonicalPath) + val filtered = textStream.filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData("drop1\nkeep2\nkeep3", src, tmp, "text text"), + CheckAnswer("keep2", "keep3") + ) + } + } + + 
test("SPARK-21996 read from text files generated by file sink -- file name has space") { + val testTableName = "FileStreamSourceTest" + withTable(testTableName) { + withTempDirs { case (src, checkpoint) => + val output = new File(src, "text text") + val inputData = MemoryStream[String] + val ds = inputData.toDS() + + val query = ds.writeStream + .option("checkpointLocation", checkpoint.getCanonicalPath) + .format("text") + .start(output.getCanonicalPath) + + try { + inputData.addData("foo") + failAfter(streamingTimeout) { + query.processAllAvailable() + } + } finally { + query.stop() + } + + val df2 = spark.readStream.format("text").load(output.getCanonicalPath) + val query2 = df2.writeStream.format("memory").queryName(testTableName).start() + try { + query2.processAllAvailable() + checkDatasetUnorderly(spark.table(testTableName).as[String], "foo") + } finally { + query2.stop() + } + } + } + } + test("read from textfile") { withTempDirs { case (src, tmp) => val textStream = spark.readStream.textFile(src.getCanonicalPath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index de2b51678cea..b1416bff87ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -42,8 +42,7 @@ case class RunningCount(count: Long) case class Result(key: Long, count: Int) class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll - with StatefulOperatorTest { + with BeforeAndAfterAll { import testImplicits._ import GroupStateImpl._ @@ -618,8 +617,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec]( - sq, Seq("value"))), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index 368c4604dfca..e45f9d3e2e97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -17,20 +17,58 @@ package org.apache.spark.sql.streaming +import org.apache.spark.sql.execution.streaming.StreamExecution + trait StateStoreMetricsTest extends StreamTest { + private var lastCheckedRecentProgressIndex = -1 + private var lastQuery: StreamExecution = null + + override def beforeEach(): Unit = { + super.beforeEach() + lastCheckedRecentProgressIndex = -1 + } + def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert( - progressWithData.stateOperators.map(_.numRowsTotal) === total, - "incorrect total rows") - assert( - progressWithData.stateOperators.map(_.numRowsUpdated) === updated, - "incorrect updates rows") + val recentProgress = q.recentProgress + require(recentProgress.nonEmpty, "No progress made, cannot check num state rows") + require(recentProgress.length < 
spark.sessionState.conf.streamingProgressRetention, + "This test assumes that all progresses are present in q.recentProgress but " + + "some may have been dropped due to retention limits") + + if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1 + lastQuery = q + + val numStateOperators = recentProgress.last.stateOperators.length + val progressesSinceLastCheck = recentProgress + .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length) + .filter(_.stateOperators.length == numStateOperators) + + val allNumUpdatedRowsSinceLastCheck = + progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated)) + + lazy val debugString = "recent progresses:\n" + + progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n") + + val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal) + assert(numTotalRows === total, s"incorrect total rows, $debugString") + + val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) + assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") + + lastCheckedRecentProgressIndex = recentProgress.length - 1 true } def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = assertNumStateRows(Seq(total), Seq(updated)) + + def arraySum(arraySeq: Seq[Array[Long]], arrayLength: Int): Seq[Long] = { + if (arraySeq.isEmpty) return Seq.fill(arrayLength)(0L) + + assert(arraySeq.forall(_.length == arrayLength), + "Arrays are of different lengths:\n" + arraySeq.map(_.toSeq).mkString("\n")) + (0 until arrayLength).map { index => arraySeq.map(_.apply(index)).sum } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala deleted file mode 100644 index 45142278993b..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.streaming._ - -trait StatefulOperatorTest { - /** - * Check that the output partitioning of a child operator of a Stateful operator satisfies the - * distribution that we expect for our Stateful operator. 
- */ - protected def checkChildOutputHashPartitioning[T <: StatefulOperator]( - sq: StreamingQuery, - colNames: Seq[String]): Boolean = { - val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output - val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions - val groupingAttr = attr.filter(a => colNames.contains(a.name)) - checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions)) - } - - /** - * Check that the output partitioning of a child operator of a Stateful operator satisfies the - * distribution that we expect for our Stateful operator. - */ - protected def checkChildOutputPartitioning[T <: StatefulOperator]( - sq: StreamingQuery, - expectedPartitioning: Partitioning): Boolean = { - val operator = sq.asInstanceOf[StreamExecution].lastExecution - .executedPlan.collect { case p: T => p } - operator.head.children.forall( - _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c65e5d3dd75c..c1ec1eba69fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,20 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("Streaming RelationV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithoutExtended).size === 0) + assert("ScanV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. 
- assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("Streaming RelationV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithExtended).size === 3) + assert("ScanV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4b7f0fbe97d4..9d139a927bea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -37,11 +37,14 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.AllTuples import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -80,6 +83,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } + protected val defaultTrigger = Trigger.ProcessingTime(0) + protected val defaultUseV2Sink = false + /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -94,10 +100,23 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * been processed. */ object AddData { - def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] = + def apply[A](source: MemoryStreamBase[A], data: A*): AddDataMemory[A] = AddDataMemory(source, data) } + /** + * Adds data to multiple memory streams such that all the data will be made visible in the + * same batch. This is applicable only to MicroBatchExecution, as this coordination cannot be + * performed at the driver in ContinuousExecutions. + */ + object MultiAddData { + def apply[A] + (source1: MemoryStream[A], data1: A*)(source2: MemoryStream[A], data2: A*): StreamAction = { + val actions = Seq(AddDataMemory(source1, data1), AddDataMemory(source2, data2)) + StreamProgressLockedActions(actions, desc = actions.mkString("[ ", " | ", " ]")) + } + } + /** A trait that can be extended when testing a source. */ trait AddData extends StreamAction { /** @@ -105,7 +124,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * the active query, and then return the source object the data was added, as well as the * offset of added data. 
*/ - def addData(query: Option[StreamExecution]): (Source, Offset) + def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) } /** A trait that can be extended when testing a source. */ @@ -113,10 +132,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def runAction(): Unit } - case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { + case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { (source, source.addData(data)) } } @@ -174,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class CheckAnswerRowsContains(expectedAnswer: Seq[Row], lastOnly: Boolean = false) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" - private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" + private def operatorName = if (lastOnly) "CheckLastBatchContains" else "CheckAnswerContains" } case class CheckAnswerRowsByFunc( @@ -184,12 +203,29 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } + case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + + private def operatorName = "CheckNewAnswer" + } + + object CheckNewAnswer { + def apply(): CheckNewAnswerRows = CheckNewAnswerRows(Seq.empty) + + def apply[A: Encoder](data: A, moreData: A*): CheckNewAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() + CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d)))) + } + } + /** Stops the stream. It must currently be running. */ case object StopStream extends StreamAction with StreamMustBeRunning /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = Trigger.ProcessingTime(0), + trigger: Trigger = defaultTrigger, triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -213,6 +249,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be s"ExpectFailure[${causeClass.getName}, isFatalError: $isFatalError]" } + /** + * Performs multiple actions while locking the stream from progressing. + * This is applicable only to MicroBatchExecution, as progress of ContinuousExecution + * cannot be controlled from the driver. 
+ */ + case class StreamProgressLockedActions(actions: Seq[StreamAction], desc: String = null) + extends StreamAction { + + override def toString(): String = { + if (desc != null) desc else super.toString + } + } + /** Assert that a body is true */ class Assert(condition: => Boolean, val message: String = "") extends StreamAction { def run(): Unit = { Assertions.assert(condition) } @@ -259,7 +308,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def apply(): AssertOnQuery = Execute { case s: ContinuousExecution => - val newEpoch = EpochCoordinatorRef.get(s.runId.toString, SparkEnv.get) + val newEpoch = EpochCoordinatorRef.get(s.currentEpochCoordinatorId, SparkEnv.get) .askSync[Long](IncrementAndGetEpoch) s.awaitEpoch(newEpoch - 1) case _ => throw new IllegalStateException("microbatch cannot increment epoch") @@ -276,7 +325,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -291,6 +340,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) val resetConfValues = mutable.Map[String, Option[String]]() + val defaultCheckpointLocation = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + var manualClockExpectedTime = -1L @volatile var streamThreadDeathCause: Throwable = null @@ -401,254 +453,316 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be """.stripMargin) } - def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { + var lastFetchedMemorySinkLastBatchId: Long = -1 + + def fetchStreamAnswer( + currentStream: StreamExecution, + lastOnly: Boolean = false, + sinceLastFetchOnly: Boolean = false) = { + verify( + !(lastOnly && sinceLastFetchOnly), "both lastOnly and sinceLastFetchOnly cannot be true") verify(currentStream != null, "stream not running") - // Get the map of source index to the current source objects - val indexToSource = currentStream - .logicalPlan - .collect { case StreamingExecutionRelation(s, _) => s } - .zipWithIndex - .map(_.swap) - .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(indexToSource(sourceIndex), offset) + currentStream.awaitOffset(sourceIndex, offset) + // Make sure all processing including no-data-batches have been executed + if (!currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + currentStream.processAllAvailable() + } } } - val (latestBatchData, allData) = sink match { - case s: MemorySink => (s.latestBatchData, s.allData) - case s: MemorySinkV2 => (s.latestBatchData, s.allData) + val lastExecution = currentStream.lastExecution + if (currentStream.isInstanceOf[MicroBatchExecution] && lastExecution != null) { + // Verify if stateful operators have correct metadata and distribution + // This can often catch hard to debug errors when developing stateful operators + 
lastExecution.executedPlan.collect { case s: StatefulOperator => s }.foreach { s => + assert(s.stateInfo.map(_.numPartitions).contains(lastExecution.numStateStores)) + s.requiredChildDistribution.foreach { d => + withClue(s"$s specifies incorrect # partitions in requiredChildDistribution $d") { + assert(d.requiredNumPartitions.isDefined) + assert(d.requiredNumPartitions.get >= 1) + if (d != AllTuples) { + assert(d.requiredNumPartitions.get == s.stateInfo.get.numPartitions) + } + } + } + } } - try if (lastOnly) latestBatchData else allData catch { + + val rows = try { + if (sinceLastFetchOnly) { + if (sink.latestBatchId.getOrElse(-1L) < lastFetchedMemorySinkLastBatchId) { + failTest("MemorySink was probably cleared since last fetch. Use CheckAnswer instead.") + } + sink.dataSinceBatch(lastFetchedMemorySinkLastBatchId) + } else { + if (lastOnly) sink.latestBatchData else sink.allData + } + } catch { case e: Exception => failTest("Exception while getting data from sink", e) } + lastFetchedMemorySinkLastBatchId = sink.latestBatchId.getOrElse(-1L) + rows } - var manualClockExpectedTime = -1L - val defaultCheckpointLocation = - Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - try { - startedTest.foreach { action => - logInfo(s"Processing test stream action: $action") - action match { - case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => - verify(currentStream == null, "stream already running") - verify(triggerClock.isInstanceOf[SystemClock] - || triggerClock.isInstanceOf[StreamManualClock], - "Use either SystemClock or StreamManualClock to start the stream") - if (triggerClock.isInstanceOf[StreamManualClock]) { - manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + def executeAction(action: StreamAction): Unit = { + logInfo(s"Processing test stream action: $action") + action match { + case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => + verify(currentStream == null, "stream already running") + verify(triggerClock.isInstanceOf[SystemClock] + || triggerClock.isInstanceOf[StreamManualClock], + "Use either SystemClock or StreamManualClock to start the stream") + if (triggerClock.isInstanceOf[StreamManualClock]) { + manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + } + val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) + + additionalConfs.foreach(pair => { + val value = + if (sparkSession.conf.contains(pair._1)) { + Some(sparkSession.conf.get(pair._1)) + } else None + resetConfValues(pair._1) = value + sparkSession.conf.set(pair._1, pair._2) + }) + + lastStream = currentStream + currentStream = + sparkSession + .streams + .startQuery( + None, + Some(metadataRoot), + stream, + Map(), + sink, + outputMode, + trigger = trigger, + triggerClock = triggerClock) + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + // Wait until the initialization finishes, because some tests need to use `logicalPlan` + // after starting the query. + try { + currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + assert(s.lastExecution != null) + } + case _ => } - val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) + } catch { + case _: StreamingQueryException => + // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. 
+ } - additionalConfs.foreach(pair => { - val value = - if (sparkSession.conf.contains(pair._1)) { - Some(sparkSession.conf.get(pair._1)) - } else None - resetConfValues(pair._1) = value - sparkSession.conf.set(pair._1, pair._2) - }) + case AdvanceManualClock(timeToAdd) => + verify(currentStream != null, + "can not advance manual clock when a stream is not running") + verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], + s"can not advance clock of type ${currentStream.triggerClock.getClass}") + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + assert(manualClockExpectedTime >= 0) + + // Make sure we don't advance ManualClock too early. See SPARK-16002. + eventually("StreamManualClock has not yet entered the waiting state") { + assert(clock.isStreamWaitingAt(manualClockExpectedTime)) + } + clock.advance(timeToAdd) + manualClockExpectedTime += timeToAdd + verify(clock.getTimeMillis() === manualClockExpectedTime, + s"Unexpected clock time after updating: " + + s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") + + case StopStream => + verify(currentStream != null, "can not stop a stream that is not running") + try failAfter(streamingTimeout) { + currentStream.stop() + verify(!currentStream.queryExecutionThread.isAlive, + s"microbatch thread not stopped") + verify(!currentStream.isActive, + "query.isActive() is false even after stopping") + verify(currentStream.exception.isEmpty, + s"query.exception() is not empty after clean stop: " + + currentStream.exception.map(_.toString()).getOrElse("")) + } catch { + case _: InterruptedException => + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest( + "Timed out while stopping and waiting for microbatchthread to terminate.", e) + case t: Throwable => + failTest("Error while stopping stream", t) + } finally { lastStream = currentStream - currentStream = - sparkSession - .streams - .startQuery( - None, - Some(metadataRoot), - stream, - Map(), - sink, - outputMode, - trigger = trigger, - triggerClock = triggerClock) - .asInstanceOf[StreamingQueryWrapper] - .streamingQuery - // Wait until the initialization finishes, because some tests need to use `logicalPlan` - // after starting the query. - try { - currentStream.awaitInitialization(streamingTimeout.toMillis) - } catch { - case _: StreamingQueryException => - // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. - } + currentStream = null + } - case AdvanceManualClock(timeToAdd) => - verify(currentStream != null, - "can not advance manual clock when a stream is not running") - verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], - s"can not advance clock of type ${currentStream.triggerClock.getClass}") - val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] - assert(manualClockExpectedTime >= 0) - - // Make sure we don't advance ManualClock too early. See SPARK-16002. 
- eventually("StreamManualClock has not yet entered the waiting state") { - assert(clock.isStreamWaitingAt(manualClockExpectedTime)) + case ef: ExpectFailure[_] => + verify(currentStream != null, "can not expect failure when stream is not running") + try failAfter(streamingTimeout) { + val thrownException = intercept[StreamingQueryException] { + currentStream.awaitTermination() } - - clock.advance(timeToAdd) - manualClockExpectedTime += timeToAdd - verify(clock.getTimeMillis() === manualClockExpectedTime, - s"Unexpected clock time after updating: " + - s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") - - case StopStream => - verify(currentStream != null, "can not stop a stream that is not running") - try failAfter(streamingTimeout) { - currentStream.stop() - verify(!currentStream.queryExecutionThread.isAlive, - s"microbatch thread not stopped") - verify(!currentStream.isActive, - "query.isActive() is false even after stopping") - verify(currentStream.exception.isEmpty, - s"query.exception() is not empty after clean stop: " + - currentStream.exception.map(_.toString()).getOrElse("")) - } catch { - case _: InterruptedException => - case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest( - "Timed out while stopping and waiting for microbatchthread to terminate.", e) - case t: Throwable => - failTest("Error while stopping stream", t) - } finally { - lastStream = currentStream - currentStream = null + eventually("microbatch thread not stopped after termination with failure") { + assert(!currentStream.queryExecutionThread.isAlive) } + verify(currentStream.exception === Some(thrownException), + s"incorrect exception returned by query.exception()") + + val exception = currentStream.exception.get + verify(exception.cause.getClass === ef.causeClass, + "incorrect cause in exception returned by query.exception()\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") + if (ef.isFatalError) { + // This is a fatal error, `streamThreadDeathCause` should be set to this error in + // UncaughtExceptionHandler. 
+ verify(streamThreadDeathCause != null && + streamThreadDeathCause.getClass === ef.causeClass, + "UncaughtExceptionHandler didn't receive the correct error\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") + streamThreadDeathCause = null + } + ef.assertFailure(exception.getCause) + } catch { + case _: InterruptedException => + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out while waiting for failure", e) + case t: Throwable => + failTest("Error while checking stream failure", t) + } finally { + lastStream = currentStream + currentStream = null + } - case ef: ExpectFailure[_] => - verify(currentStream != null, "can not expect failure when stream is not running") - try failAfter(streamingTimeout) { - val thrownException = intercept[StreamingQueryException] { - currentStream.awaitTermination() - } - eventually("microbatch thread not stopped after termination with failure") { - assert(!currentStream.queryExecutionThread.isAlive) + case a: AssertOnQuery => + verify(currentStream != null || lastStream != null, + "cannot assert when no stream has been started") + val streamToAssert = Option(currentStream).getOrElse(lastStream) + try { + verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + } catch { + case NonFatal(e) => + failTest(s"Assert on query failed: ${a.message}", e) + } + + case a: Assert => + val streamToAssert = Option(currentStream).getOrElse(lastStream) + verify({ a.run(); true }, s"Assert failed: ${a.message}") + + case a: AddData => + try { + + // If the query is running with manual clock, then wait for the stream execution + // thread to start waiting for the clock to increment. This is needed so that we + // are adding data when there is no trigger that is active. This would ensure that + // the data gets deterministically added to the next batch triggered after the manual + // clock is incremented in following AdvanceManualClock. This avoid race conditions + // between the test thread and the stream execution thread in tests using manual + // clock. + if (currentStream != null && + currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + eventually("Error while synchronizing with manual clock before adding data") { + if (currentStream.isActive) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } } - verify(currentStream.exception === Some(thrownException), - s"incorrect exception returned by query.exception()") - - val exception = currentStream.exception.get - verify(exception.cause.getClass === ef.causeClass, - "incorrect cause in exception returned by query.exception()\n" + - s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") - if (ef.isFatalError) { - // This is a fatal error, `streamThreadDeathCause` should be set to this error in - // UncaughtExceptionHandler. 
- verify(streamThreadDeathCause != null && - streamThreadDeathCause.getClass === ef.causeClass, - "UncaughtExceptionHandler didn't receive the correct error\n" + - s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") - streamThreadDeathCause = null + if (!currentStream.isActive) { + failTest("Query terminated while synchronizing with manual clock") } - ef.assertFailure(exception.getCause) - } catch { - case _: InterruptedException => - case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest("Timed out while waiting for failure", e) - case t: Throwable => - failTest("Error while checking stream failure", t) - } finally { - lastStream = currentStream - currentStream = null } - - case a: AssertOnQuery => - verify(currentStream != null || lastStream != null, - "cannot assert when no stream has been started") - val streamToAssert = Option(currentStream).getOrElse(lastStream) - try { - verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") - } catch { - case NonFatal(e) => - failTest(s"Assert on query failed: ${a.message}", e) + // Add data + val queryToUse = Option(currentStream).orElse(Option(lastStream)) + val (source, offset) = a.addData(queryToUse) + + def findSourceIndex(plan: LogicalPlan): Option[Int] = { + plan + .collect { + case r: StreamingExecutionRelation => r.source + case r: StreamingDataSourceV2Relation => r.reader + } + .zipWithIndex + .find(_._1 == source) + .map(_._2) } - case a: Assert => - val streamToAssert = Option(currentStream).getOrElse(lastStream) - verify({ a.run(); true }, s"Assert failed: ${a.message}") - - case a: AddData => - try { - - // If the query is running with manual clock, then wait for the stream execution - // thread to start waiting for the clock to increment. This is needed so that we - // are adding data when there is no trigger that is active. This would ensure that - // the data gets deterministically added to the next batch triggered after the manual - // clock is incremented in following AdvanceManualClock. This avoid race conditions - // between the test thread and the stream execution thread in tests using manual - // clock. - if (currentStream != null && - currentStream.triggerClock.isInstanceOf[StreamManualClock]) { - val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] - eventually("Error while synchronizing with manual clock before adding data") { - if (currentStream.isActive) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } - } - if (!currentStream.isActive) { - failTest("Query terminated while synchronizing with manual clock") + // Try to find the index of the source to which data was added. Either get the index + // from the current active query or the original input logical plan. + val sourceIndex = + queryToUse.flatMap { query => + findSourceIndex(query.logicalPlan) + }.orElse { + findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) } - } - // Add data - val queryToUse = Option(currentStream).orElse(Option(lastStream)) - val (source, offset) = a.addData(queryToUse) - - def findSourceIndex(plan: LogicalPlan): Option[Int] = { - plan - .collect { case StreamingExecutionRelation(s, _) => s } - .zipWithIndex - .find(_._1 == source) - .map(_._2) + }.getOrElse { + throw new IllegalArgumentException( + "Could not find index of the source to which data was added") } - // Try to find the index of the source to which data was added. 
Either get the index - // from the current active query or the original input logical plan. - val sourceIndex = - queryToUse.flatMap { query => - findSourceIndex(query.logicalPlan) - }.orElse { - findSourceIndex(stream.logicalPlan) - }.getOrElse { - throw new IllegalArgumentException( - "Could find index of the source to which data was added") - } + // Store the expected offset of added data to wait for it later + awaiting.put(sourceIndex, offset) + } catch { + case NonFatal(e) => + failTest("Error adding data", e) + } - // Store the expected offset of added data to wait for it later - awaiting.put(sourceIndex, offset) - } catch { - case NonFatal(e) => - failTest("Error adding data", e) - } + case e: ExternalAction => + e.runAction() - case e: ExternalAction => - e.runAction() + case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { + error => failTest(error) + } - case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { - error => failTest(error) - } + case CheckAnswerRowsContains(expectedAnswer, lastOnly) => + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } + QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } - case CheckAnswerRowsContains(expectedAnswer, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { - error => failTest(error) - } + case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + try { + globalCheckFunction(sparkAnswer) + } catch { + case e: Throwable => failTest(e.toString) + } - case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - try { - globalCheckFunction(sparkAnswer) - } catch { - case e: Throwable => failTest(e.toString) - } - } - pos += 1 + case CheckNewAnswerRows(expectedAnswer) => + val sparkAnswer = fetchStreamAnswer(currentStream, sinceLastFetchOnly = true) + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } + } + pos += 1 + } + + try { + startedTest.foreach { + case StreamProgressLockedActions(actns, _) => + // Perform actions while holding the stream from progressing + assert(currentStream != null, + s"Cannot perform stream-progress-locked actions $actns when query is not active") + assert(currentStream.isInstanceOf[MicroBatchExecution], + s"Cannot perform stream-progress-locked actions on non-microbatch queries") + currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked { + actns.foreach(executeAction) + } + + case action: StreamAction => executeAction(action) } if (streamThreadDeathCause != null) { failTest("Stream Thread Died", streamThreadDeathCause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 97e065193fd0..382da1343078 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -44,7 +44,7 @@ 
object FailureSingleton { } class StreamingAggregationSuite extends StateStoreMetricsTest - with BeforeAndAfterAll with Assertions with StatefulOperatorTest { + with BeforeAndAfterAll with Assertions { override def afterAll(): Unit = { super.afterAll() @@ -281,8 +281,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(10 * 1000), CheckLastBatch((0L, 1), (5L, 2), (10L, 1)), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))), // advance clock to 20 seconds, should retain keys >= 10 AddData(inputData, 15L, 15L, 20L), @@ -538,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } + test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") { + // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error + // by ensuring the following. + // - A streaming query with a streaming aggregation. + // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate. + // - Post shuffle partition has exactly 128 records (i.e. the threshold at which + // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a + // micro-batch with 128 records that shuffle to a single partition. + // This test throws the exact error reported in SPARK-23004 without the corresponding fix. + withSQLConf("spark.sql.shuffle.partitions" -> "1") { + val input = MemoryStream[Int] + val df = input.toDF().toDF("value") + .selectExpr("value as group", "value") + .groupBy("group") + .agg(collect_list("value")) + testStream(df, outputMode = OutputMode.Update)( + AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), + AssertOnQuery { q => + q.processAllAvailable() + true + } + ) + } + } + /** Add blocks of data to the `BlockRDDBackedSource`. 
*/ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala similarity index 80% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index caf2bab8a585..42ffd472eb84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -24,10 +24,9 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf -class DeduplicateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll - with StatefulOperatorTest { +class StreamingDeduplicationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -44,8 +43,6 @@ class DeduplicateSuite extends StateStoreMetricsTest AddData(inputData, "a"), CheckLastBatch("a"), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))), AddData(inputData, "a"), CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), @@ -63,8 +60,6 @@ class DeduplicateSuite extends StateStoreMetricsTest AddData(inputData, "a" -> 1), CheckLastBatch("a" -> 1), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))), AddData(inputData, "a" -> 2), // Dropped CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), @@ -103,28 +98,20 @@ class DeduplicateSuite extends StateStoreMetricsTest testStream(result, Append)( AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), - CheckLastBatch(10 to 15: _*), + CheckAnswer(10 to 15: _*), assertNumStateRows(total = 6, updated = 6), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(25), - assertNumStateRows(total = 7, updated = 1), - - AddData(inputData, 25), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0), + AddData(inputData, 25), // Advance watermark to 15 secs, no-data-batch drops rows <= 15 + CheckNewAnswer(25), + assertNumStateRows(total = 1, updated = 1), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(total = 1, updated = 0), - AddData(inputData, 45), // Advance watermark to 35 seconds - CheckLastBatch(45), - assertNumStateRows(total = 2, updated = 1), - - AddData(inputData, 45), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0) + AddData(inputData, 45), // Advance watermark to 35 seconds, no-data-batch drops row 25 + CheckNewAnswer(45), + assertNumStateRows(total = 1, updated = 1) ) } @@ -147,33 +134,20 @@ class DeduplicateSuite extends StateStoreMetricsTest assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - // states in 
aggregate in [10, 14), [15, 20) and [25, 30) (3 windows) - // states in deduplicate is 10 to 15 and 25 - assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)), - - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate - // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of - // window to evict items, so [15, 20) is still in the state store) - // states in deduplicate is 25 - assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate, emitted with no-data-batch + // states in aggregate in [15, 20) and [25, 30); no-data-batch removed [10, 14) + // states in deduplicate is 25, no-data-batch removed 10 to 14 + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(1L, 1L)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckLastBatch(), assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), AddData(inputData, 40), // Advance watermark to 30 seconds - CheckLastBatch(), - // states in aggregate in [15, 20), [25, 30) and [40, 45) - // states in deduplicate is 25 and 40, - assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)), - - AddData(inputData, 40), // Emit items less than watermark and drop their state CheckLastBatch((15 -> 1), (25 -> 1)), - // states in aggregate in [40, 45) - // states in deduplicate is 40, - assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)) + // states in aggregate is [40, 45); no-data-batch removed [15, 20) and [25, 30) + // states in deduplicate is 40; no-data-batch removed 25 + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)) ) } @@ -266,13 +240,13 @@ class DeduplicateSuite extends StateStoreMetricsTest .select($"id") testStream(df)( AddData(input, 1 -> 1, 1 -> 1, 1 -> 2), - CheckLastBatch(1, 2), + CheckAnswer(1, 2), AddData(input, 1 -> 1, 2 -> 3, 2 -> 4), - CheckLastBatch(3, 4), + CheckNewAnswer(3, 4), AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark - CheckLastBatch(5, 6), + CheckNewAnswer(5, 6), AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark - CheckLastBatch(7) + CheckNewAnswer(7) ) } @@ -285,7 +259,37 @@ class DeduplicateSuite extends StateStoreMetricsTest .select($"id", $"time".cast("long")) testStream(df)( AddData(input, 1 -> 1, 1 -> 2, 2 -> 2), - CheckLastBatch(1 -> 1, 2 -> 2) + CheckAnswer(1 -> 1, 2 -> 2) ) } + + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Append)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(10, 11, 12, 13, 14, 15), + assertNumStateRows(total = 6, updated = 6), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckNewAnswer(25), + { // State should have been cleaned if flag is set, otherwise should not have been cleaned + if (flag) assertNumStateRows(total = 1, updated = 1) + else assertNumStateRows(total = 7, updated = 1) + } + ) + } + + testWithFlag(true) + testWithFlag(false) + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 54eb863dacc8..da8f9608c1e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -28,7 +28,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} -import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.execution.{FileSourceScanExec, LogicalRDD} +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId} import org.apache.spark.sql.functions._ @@ -190,7 +192,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((1, 5, 11)), AddData(rightInput, (1, 10)), CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 - assertNumStateRows(total = 3, updated = 1), + assertNumStateRows(total = 3, updated = 3), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 3), (1, 30)), @@ -274,14 +276,14 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with CheckAnswer(), AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), - assertNumStateRows(total = 7, updated = 6), + assertNumStateRows(total = 7, updated = 7), // If rightTime = 60, then it matches only leftTime = [50, 65] AddData(rightInput, (1, 60)), CheckLastBatch(), // matches with nothing on the left AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), CheckLastBatch((1, 50, 60), (1, 65, 60)), - assertNumStateRows(total = 12, updated = 4), + assertNumStateRows(total = 12, updated = 5), // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) @@ -323,6 +325,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) } + test("stream stream self join") { + val input = MemoryStream[Int] + val df = input.toDF + val join = + df.select('value % 5 as "key", 'value).join( + df.select('value % 5 as "key", 'value), "key") + + testStream(join)( + AddData(input, 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + StopStream, + StartStream(), + AddData(input, 3, 6), + /* + (1, 1) (1, 1) + (2, 2) x (2, 2) = (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6) + (1, 6) (1, 6) + */ + CheckAnswer((3, 3, 3), (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6))) + } + test("locality preferences of StateStoreAwareZippedRDD") { import StreamingSymmetricHashJoinHelper._ @@ -439,15 +462,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(left("key"), left("window.end").cast("long"), 'leftValue, 
'rightValue) testStream(joined)( - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 3, 4, 5), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with leftValue <= 4 should generate their outer join row now and // not get added to the state. CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch((20, 30, 40, "60")) @@ -470,15 +491,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 3, 4, 5), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with value <= 7 should never be added to the state. CheckLastBatch(Row(3, 10, 6, "9")), assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) @@ -501,15 +520,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 3, 4, 5), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with value <= 4 should never be added to the state. CheckLastBatch(Row(3, 10, 6, "9")), assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) @@ -532,15 +549,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 3, 4, 5), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with rightValue <= 7 should generate their outer join row now and // not get added to the state. CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch((20, 30, 40, "60")) @@ -552,15 +567,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. 
- AddData(leftInput, 1, 2, 3, 4, 5), - AddData(rightInput, 3, 4, 5, 6, 7), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), // Old state doesn't get dropped until the batch *after* it gets introduced, so the // nulls won't show up until the next batch after the watermark advances. - AddData(leftInput, 21), - AddData(rightInput, 22), + MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + assertNumStateRows(total = 12, updated = 12), AddData(leftInput, 22), CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 3, updated = 1) @@ -572,15 +585,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. - AddData(leftInput, 1, 2, 3, 4, 5), - AddData(rightInput, 3, 4, 5, 6, 7), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), // Old state doesn't get dropped until the batch *after* it gets introduced, so the // nulls won't show up until the next batch after the watermark advances. - AddData(leftInput, 21), - AddData(rightInput, 22), + MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + assertNumStateRows(total = 12, updated = 12), AddData(leftInput, 22), CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), assertNumStateRows(total = 3, updated = 1) @@ -619,7 +630,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((1, 1, 5, 10)), AddData(rightInput, (1, 11)), CheckLastBatch(), // no match as left time is too low - assertNumStateRows(total = 5, updated = 1), + assertNumStateRows(total = 5, updated = 5), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 7), (1, 30)), @@ -653,36 +664,30 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3), CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), - assertNumStateRows(total = 5, updated = 2), + assertNumStateRows(total = 5, updated = 5), // 1...3 added, but 20 and 21 not added AddData(rightInput, 20), CheckLastBatch( Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), // leftValue and rightValue both satisfying condition should not generate outer join rows - AddData(leftInput, 40, 41), - AddData(rightInput, 40, 41), + MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), - AddData(leftInput, 70), - AddData(rightInput, 71), + MultiAddData(leftInput, 70)(rightInput, 71), CheckLastBatch(), - assertNumStateRows(total = 6, updated = 2), + assertNumStateRows(total = 6, updated = 6), // all inputs added since last check AddData(rightInput, 70), CheckLastBatch((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), // rightValue between 300 and 1000 should generate outer join rows even though it matches left - AddData(leftInput, 101, 102, 103), - 
AddData(rightInput, 101, 102, 103), + MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), CheckLastBatch(), - AddData(leftInput, 1000), - AddData(rightInput, 1001), + MultiAddData(leftInput, 1000)(rightInput, 1001), CheckLastBatch(), - assertNumStateRows(total = 8, updated = 2), + assertNumStateRows(total = 8, updated = 5), // 101...103 added, but 1000 and 1001 not added AddData(rightInput, 1000), CheckLastBatch( Row(1000, 1010, 2000, 3000), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 9ff02dee288f..b96f2bcbdd64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.util.JsonProtocol @@ -174,6 +175,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } } + test("continuous processing listeners should receive QueryTerminatedEvent") { + val df = spark.readStream.format("rate").load() + val listeners = (1 to 5).map(_ => new EventCollector) + try { + listeners.foreach(listener => spark.streams.addListener(listener)) + testStream(df, OutputMode.Append, useV2Sink = true)( + StartStream(Trigger.Continuous(1000)), + StopStream, + AssertOnQuery { query => + eventually(Timeout(streamingTimeout)) { + listeners.foreach(listener => assert(listener.terminationEvent !== null)) + listeners.foreach(listener => assert(listener.terminationEvent.id === query.id)) + listeners.foreach(listener => assert(listener.terminationEvent.runId === query.runId)) + listeners.foreach(listener => assert(listener.terminationEvent.exception === None)) + } + listeners.foreach(listener => listener.checkAsyncErrors()) + listeners.foreach(listener => listener.reset()) + true + } + ) + } finally { + listeners.foreach(spark.streams.removeListener) + } + } + test("adding and removing listener") { def isListenerActive(listener: EventCollector): Boolean = { listener.reset() @@ -273,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def getOffset: Option[Offset] = { + override def getEndOffset: OffsetV2 = { numTriggers += 1 - super.getOffset + super.getEndOffset } } val clock = new StreamManualClock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 2fa4595dab37..0cb2375e0a49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,25 +17,28 @@ package org.apache.spark.sql.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.CountDownLatch import org.apache.commons.lang3.RandomStringUtils -import org.mockito.Mockito._ import org.scalactic.TolerantNumerics 
import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ManualClock class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar { @@ -206,19 +209,29 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi /** Custom MemoryStream that waits for manual clock to reach a time */ val inputData = new MemoryStream[Int](0, sqlContext) { - // getOffset should take 50 ms the first time it is called - override def getOffset: Option[Offset] = { - val offset = super.getOffset - if (offset.nonEmpty) { - clock.waitTillTime(1050) + + private def dataAdded: Boolean = currentOffset.offset != -1 + + // setOffsetRange should take 50 ms the first time it is called after data is added + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + if (dataAdded) clock.waitTillTime(1050) + super.setOffsetRange(start, end) } - offset + } + + // getEndOffset should take 100 ms the first time it is called after data is added + override def getEndOffset(): OffsetV2 = synchronized { + if (dataAdded) clock.waitTillTime(1150) + super.getEndOffset() } // getBatch should take 100 ms the first time it is called - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - if (start.isEmpty) clock.waitTillTime(1150) - super.getBatch(start, end) + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + synchronized { + clock.waitTillTime(1350) + super.createUnsafeRowReaderFactories() + } } } @@ -258,39 +271,44 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while offset is being fetched + // Test status and progress when setOffsetRange is being called AddData(inputData, 1, 2), - AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on getOffset + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while batch is being fetched - AdvanceManualClock(50), // time = 1050 to unblock getOffset + AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange AssertClockTime(1050), - AssertStreamExecThreadIsWaitingForTime(1150), // will block on getBatch that needs 1150 + 
AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150 + AssertOnQuery(_.status.isDataAvailable === false), + AssertOnQuery(_.status.isTriggerActive === true), + AssertOnQuery(_.status.message.startsWith("Getting offsets from")), + AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), + + AdvanceManualClock(100), // time = 1150 to unblock getEndOffset + AssertClockTime(1150), + AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while batch is being processed - AdvanceManualClock(100), // time = 1150 to unblock getBatch - AssertClockTime(1150), - AssertStreamExecThreadIsWaitingForTime(1500), // will block in Spark job that needs 1500 + AdvanceManualClock(200), // time = 1350 to unblock createReadTasks + AssertClockTime(1350), + AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AssertOnQuery { _ => clock.getTimeMillis() === 1150 }, - AdvanceManualClock(350), // time = 1500 to unblock job + AdvanceManualClock(150), // time = 1500 to unblock map task AssertClockTime(1500), CheckAnswer(2), - AssertStreamExecThreadIsWaitingForTime(2000), + AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -307,15 +325,16 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.numInputRows === 2) assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("getOffset") === 50) - assert(progress.durationMs.get("getBatch") === 100) - assert(progress.durationMs.get("queryPlanning") === 0) + assert(progress.durationMs.get("setOffsetRange") === 50) + assert(progress.durationMs.get("getEndOffset") === 100) + assert(progress.durationMs.get("queryPlanning") === 200) assert(progress.durationMs.get("walCommit") === 0) + assert(progress.durationMs.get("addBatch") === 150) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") - assert(progress.sources(0).startOffset === null) + assert(progress.sources(0).startOffset === "0") assert(progress.sources(0).endOffset !== null) assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms @@ -424,7 +443,40 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } - test("input row calculation with mixed batch and streaming sources") { + test("SPARK-22975: MetricsReporter defaults when there was no progress reported") { + withSQLConf("spark.sql.streaming.metricsEnabled" -> "true") { + BlockingSource.latch = new CountDownLatch(1) + withTempDir { tempDir => + val sq = spark.readStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .load() + .writeStream + 
.format("org.apache.spark.sql.streaming.util.BlockingSource") + .option("checkpointLocation", tempDir.toString) + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + + val gauges = sq.streamMetrics.metricRegistry.getGauges + assert(gauges.get("latency").getValue.asInstanceOf[Long] == 0) + assert(gauges.get("processingRate-total").getValue.asInstanceOf[Double] == 0.0) + assert(gauges.get("inputRate-total").getValue.asInstanceOf[Double] == 0.0) + sq.stop() + } + } + } + + test("input row calculation with same V1 source used twice in self-join") { + val streamingTriggerDF = spark.createDataset(1 to 10).toDF + val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") + + val progress = getFirstProgress(streamingInputDF.join(streamingInputDF, "value")) + assert(progress.numInputRows === 20) // data is read multiple times in self-joins + assert(progress.sources.size === 1) + assert(progress.sources(0).numInputRows === 20) + } + + test("input row calculation with mixed batch and streaming V1 sources") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") @@ -437,7 +489,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } - test("input row calculation with trigger input DF having multiple leaves") { + test("input row calculation with trigger input DF having multiple leaves in V1 source") { val streamingTriggerDF = spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) @@ -450,6 +502,121 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } + test("input row calculation with same V2 source used twice in self-union") { + val streamInput = MemoryStream[Int] + + testStream(streamInput.toDF().union(streamInput.toDF()), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 1, 2, 2, 3, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with same V2 source used twice in self-join") { + val streamInput = MemoryStream[Int] + val df = streamInput.toDF() + testStream(df.join(df, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with trigger having data for only one of two V2 sources") { + val streamInput1 = MemoryStream[Int] + val streamInput2 = MemoryStream[Int] + + testStream(streamInput1.toDF().union(streamInput2.toDF()), useV2Sink = true)( + AddData(streamInput1, 1, 2, 3), + CheckLastBatch(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 2) + 
assert(lastProgress.get.sources(0).numInputRows == 3) + assert(lastProgress.get.sources(1).numInputRows == 0) + true + }, + AddData(streamInput2, 4, 5), + CheckLastBatch(4, 5), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 2) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 0) + assert(lastProgress.get.sources(1).numInputRows == 2) + true + } + ) + } + + test("input row calculation with mixed batch and streaming V2 sources") { + + val streamInput = MemoryStream[Int] + val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") + + testStream(streamInput.toDF().join(staticInputDF, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + + // The number of leaves in the trigger's logical plan should be same as the executed plan. + require( + q.lastExecution.logical.collectLeaves().length == + q.lastExecution.executedPlan.collectLeaves().length) + + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + + val streamInput2 = MemoryStream[Int] + val staticInputDF2 = staticInputDF.union(staticInputDF).cache() + + testStream(streamInput2.toDF().join(staticInputDF2, "value"), useV2Sink = true)( + AddData(streamInput2, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + // The number of leaves in the trigger's logical plan should be different from + // the executed plan. The static input will have two leaves in the logical plan + // (due to the union), but will be converted to a single leaf in the executed plan + // (due to the caching, the cached subplan is replaced by a single InMemoryTableScanExec). + require( + q.lastExecution.logical.collectLeaves().length != + q.lastExecution.executedPlan.collectLeaves().length) + + // Despite the mismatch in total number of leaves in the logical and executed plans, + // we should be able to attribute streaming input metrics to the streaming sources. + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + } + testQuietly("StreamExecution metadata garbage collection") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) @@ -509,22 +676,22 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi .start() } - val input = MemoryStream[Int] - val q1 = startQuery(input.toDS, "stream_serializable_test_1") - val q2 = startQuery(input.toDS.map { i => + val input = MemoryStream[Int] :: MemoryStream[Int] :: MemoryStream[Int] :: Nil + val q1 = startQuery(input(0).toDS, "stream_serializable_test_1") + val q2 = startQuery(input(1).toDS.map { i => // Emulate that `StreamingQuery` get captured with normal usage unintentionally. // It should not fail the query. q1 i }, "stream_serializable_test_2") - val q3 = startQuery(input.toDS.map { i => + val q3 = startQuery(input(2).toDS.map { i => // Emulate that `StreamingQuery` is used in executors. We should fail the query with a clear // error message. 
q1.explain() i }, "stream_serializable_test_3") try { - input.addData(1) + input.foreach(_.addData(1)) // q2 should not fail since it doesn't use `q1` in the closure q2.processAllAvailable() @@ -691,6 +858,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + /** Returns the last query progress from query.recentProgress where numInputRows is positive */ + def getLastProgressWithData(q: StreamingQuery): Option[StreamingQueryProgress] = { + q.recentProgress.filter(_.numInputRows > 0).lastOption + } + /** * A [[StreamAction]] to test the behavior of `StreamingQuery.awaitTermination()`. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala new file mode 100644 index 000000000000..e755625d09e0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming.continuous + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} + +import org.mockito.{ArgumentCaptor, Matchers} +import org.mockito.Mockito._ +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { + case class LongPartitionOffset(offset: Long) extends PartitionOffset + + val coordinatorId = s"${getClass.getSimpleName}-epochCoordinatorIdForUnitTest" + val startEpoch = 0 + + var epochEndpoint: RpcEndpointRef = _ + + override def beforeEach(): Unit = { + super.beforeEach() + epochEndpoint = EpochCoordinatorRef.create( + mock[StreamWriter], + mock[ContinuousReader], + mock[ContinuousExecution], + coordinatorId, + startEpoch, + spark, + SparkEnv.get) + } + + override def afterEach(): Unit = { + SparkEnv.get.rpcEnv.stop(epochEndpoint) + epochEndpoint = null + super.afterEach() + } + + + private val mockContext = mock[TaskContext] + when(mockContext.getLocalProperty(ContinuousExecution.START_EPOCH_KEY)) + .thenReturn(startEpoch.toString) + when(mockContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)) + .thenReturn(coordinatorId) + + /** + * Set up a ContinuousQueuedDataReader for testing. The blocking queue can be used to send + * rows to the wrapped data reader. 
+ */ + private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { + val queue = new ArrayBlockingQueue[UnsafeRow](1024) + val factory = new DataReaderFactory[UnsafeRow] { + override def createDataReader() = new ContinuousDataReader[UnsafeRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } + + override def get = curr + + override def getOffset = LongPartitionOffset(index) + + override def close() = {} + } + } + val reader = new ContinuousQueuedDataReader( + factory, + mockContext, + dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, + epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) + + (queue, reader) + } + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + test("basic data read") { + val (input, reader) = setup() + + input.add(unsafeRow(12345)) + assert(reader.next().getInt(0) == 12345) + } + + test("basic epoch marker") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } + + test("new rows after markers") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + } + + test("new markers after rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + } + + test("alternating markers and rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + assert(reader.next().getInt(0) == 11111) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + input.add(unsafeRow(33333)) + assert(reader.next().getInt(0) == 33333) + input.add(unsafeRow(44444)) + assert(reader.next().getInt(0) == 44444) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index eda0d8ad4831..cd1704ac2fda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,36 +17,16 @@ package org.apache.spark.sql.streaming.continuous -import java.io.{File, InterruptedIOException, IOException, UncheckedIOException} -import java.nio.channels.ClosedByInterruptException -import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit} - -import scala.reflect.ClassTag -import scala.util.control.ControlThrowable - -import 
com.google.common.util.concurrent.UncheckedExecutionException -import org.apache.commons.io.FileUtils -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.{SparkContext, SparkEnv} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.plans.logical.Range -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes -import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.TestSparkSession -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils class ContinuousSuiteBase extends StreamTest { // We need more than the default local[2] to be able to schedule all partitions simultaneously. @@ -61,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: ContinuousRateStreamReader) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 @@ -74,102 +54,71 @@ class ContinuousSuiteBase extends StreamTest { // A continuous trigger that will only fire the initial time for the duration of a test. // This allows clean testing with manual epoch advancement. 
protected val longContinuousTrigger = Trigger.Continuous("1 hour") + + override protected val defaultTrigger = Trigger.Continuous(100) + override protected val defaultUseV2Sink = true } class ContinuousSuite extends ContinuousSuiteBase { import testImplicits._ - test("basic rate source") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + test("basic") { + val input = ContinuousMemoryStream[Int] - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(input.toDF())( + AddData(input, 0, 1, 2), + CheckAnswer(0, 1, 2), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), - StopStream) + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(0, 1, 2, 3, 4, 5)) } test("map") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .map(r => r.getLong(0) * 2) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().map(_.getInt(0) * 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 40, 2).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(0, 2), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(0, 2, 4, 6, 8)) } test("flatMap") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .flatMap(r => Seq(0, r.getLong(0), r.getLong(0) * 2)) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().flatMap(r => Seq(0, r.getInt(0), r.getInt(0) * 2)) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).flatMap(n => Seq(0, n, n * 2)).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer((0 to 1).flatMap(n => Seq(0, n, n * 2)): _*), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer((0 to 4).flatMap(n => Seq(0, n, n * 2)): _*)) } test("filter") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .where('value > 5) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().where('value > 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("deduplicate") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .dropDuplicates() + val input = ContinuousMemoryStream[Int] + val df = 
input.toDF().dropDuplicates() val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -177,48 +126,83 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("timestamp") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select(current_timestamp()) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().select(current_timestamp()) val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( "Continuous processing does not support current time operations.")) } + test("subquery alias") { + val input = ContinuousMemoryStream[Int] + input.toDF().createOrReplaceTempView("memory") + val test = spark.sql("select value from memory where value > 2") + + testStream(test)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) + } + test("repeatedly restart") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(df)( + StartStream(), + AddData(input, 0, 1), + CheckAnswer(0, 1), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), + StartStream(), + StopStream, + AddData(input, 2, 3), + StartStream(), + CheckAnswer(0, 1, 2, 3), StopStream) } + test("task failure kills the query") { + val input = ContinuousMemoryStream[Int] + val df = input.toDF() + + // Get an arbitrary task from this query to kill. It doesn't matter which one. + var taskId: Long = -1 + val listener = new SparkListener() { + override def onTaskStart(start: SparkListenerTaskStart): Unit = { + taskId = start.taskInfo.taskId + } + } + spark.sparkContext.addSparkListener(listener) + try { + testStream(df)( + StartStream(Trigger.Continuous(100)), + AddData(input, 0, 1, 2, 3), + Execute { _ => + // Wait until a task is started, then kill its first attempt. 
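+          // Continuous tasks are long-running, so the first attempt should still be alive when
+          // the kill arrives; the query is then expected to fail with a SparkException whose
+          // nested cause is a ContinuousTaskRetryException, rather than silently retrying.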
+ eventually(timeout(streamingTimeout)) { + assert(taskId != -1) + } + spark.sparkContext.killTaskAttempt(taskId) + }, + ExpectFailure[SparkException] { e => + e.getCause != null && e.getCause.getCause.isInstanceOf[ContinuousTaskRetryException] + }) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + test("query without test harness") { val df = spark.readStream .format("rate") @@ -226,6 +210,7 @@ class ContinuousSuite extends ContinuousSuiteBase { .option("rowsPerSecond", "2") .load() .select('value) + val query = df.writeStream .format("memory") .queryName("noharness") @@ -258,13 +243,9 @@ class ContinuousStressSuite extends ContinuousSuiteBase { AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 201)), IncrementEpoch(), - Execute { query => - val data = query.sink.asInstanceOf[MemorySinkV2].allData - val vals = data.map(_.getLong(0)).toSet - assert(scala.Range(0, 25000).forall { i => - vals.contains(i) - }) - }) + StopStream, + CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))) + ) } test("automatic epoch advancement") { @@ -280,6 +261,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 201)), IncrementEpoch(), + StopStream, CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } @@ -311,6 +293,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { StopStream, StartStream(Trigger.Continuous(2012)), AwaitEpoch(50), + StopStream, CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala new file mode 100644 index 000000000000..99e30561f81d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming.continuous + +import org.mockito.InOrder +import org.mockito.Matchers.{any, eq => eqTo} +import org.mockito.Mockito._ +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.LocalSparkSession +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.test.TestSparkSession + +class EpochCoordinatorSuite + extends SparkFunSuite + with LocalSparkSession + with MockitoSugar + with BeforeAndAfterEach { + + private var epochCoordinator: RpcEndpointRef = _ + + private var writer: StreamWriter = _ + private var query: ContinuousExecution = _ + private var orderVerifier: InOrder = _ + + override def beforeEach(): Unit = { + val reader = mock[ContinuousReader] + writer = mock[StreamWriter] + query = mock[ContinuousExecution] + orderVerifier = inOrder(writer, query) + + spark = new TestSparkSession() + + epochCoordinator + = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get) + } + + test("single epoch") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + commitPartitionEpoch(2, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + // Here and in subsequent tests this is called to make a synchronous call to EpochCoordinator + // so that mocks would have been acted upon by the time verification happens + makeSynchronousCall() + + verifyCommit(1) + } + + test("single epoch, all but one writer partition has committed") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + makeSynchronousCall() + + verifyNoCommitFor(1) + } + + test("single epoch, all but one reader partition has reported an offset") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + commitPartitionEpoch(2, 1) + reportPartitionOffset(0, 1) + + makeSynchronousCall() + + verifyNoCommitFor(1) + } + + test("consequent epochs, messages for epoch (k + 1) arrive after messages for epoch k") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + commitPartitionEpoch(0, 2) + commitPartitionEpoch(1, 2) + reportPartitionOffset(0, 2) + reportPartitionOffset(1, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2)) + } + + ignore("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 2) + commitPartitionEpoch(1, 2) + reportPartitionOffset(0, 2) + reportPartitionOffset(1, 2) + + // Message that arrives late + reportPartitionOffset(1, 1) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2)) + } + + ignore("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + 
commitPartitionEpoch(0, 3) + reportPartitionOffset(0, 3) + + commitPartitionEpoch(0, 4) + reportPartitionOffset(0, 4) + + commitPartitionEpoch(0, 2) + reportPartitionOffset(0, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2, 3, 4)) + } + + ignore("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 3) + reportPartitionOffset(0, 3) + + commitPartitionEpoch(0, 5) + reportPartitionOffset(0, 5) + + commitPartitionEpoch(0, 4) + reportPartitionOffset(0, 4) + + commitPartitionEpoch(0, 2) + reportPartitionOffset(0, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2, 3, 4, 5)) + } + + private def setWriterPartitions(numPartitions: Int): Unit = { + epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions)) + } + + private def setReaderPartitions(numPartitions: Int): Unit = { + epochCoordinator.askSync[Unit](SetReaderPartitions(numPartitions)) + } + + private def commitPartitionEpoch(partitionId: Int, epoch: Long): Unit = { + val dummyMessage: WriterCommitMessage = mock[WriterCommitMessage] + epochCoordinator.send(CommitPartitionEpoch(partitionId, epoch, dummyMessage)) + } + + private def reportPartitionOffset(partitionId: Int, epoch: Long): Unit = { + val dummyOffset: PartitionOffset = mock[PartitionOffset] + epochCoordinator.send(ReportPartitionOffset(partitionId, epoch, dummyOffset)) + } + + private def makeSynchronousCall(): Unit = { + epochCoordinator.askSync[Long](GetCurrentEpoch) + } + + private def verifyCommit(epoch: Long): Unit = { + orderVerifier.verify(writer).commit(eqTo(epoch), any()) + orderVerifier.verify(query).commit(epoch) + } + + private def verifyNoCommitFor(epoch: Long): Unit = { + verify(writer, never()).commit(eqTo(epoch), any()) + verify(query, never()).commit(epoch) + } + + private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = { + epochs.foreach(verifyCommit) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala new file mode 100644 index 000000000000..af4618bed545 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming.sources + +import java.util.Optional + +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +case class FakeReader() extends MicroBatchReader with ContinuousReader { + def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} + def getStartOffset: Offset = RateStreamOffset(Map()) + def getEndOffset: Offset = RateStreamOffset(Map()) + def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) + def commit(end: Offset): Unit = {} + def readSchema(): StructType = StructType(Seq()) + def stop(): Unit = {} + def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) + def setStartOffset(start: Optional[Offset]): Unit = {} + + def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { + throw new IllegalStateException("fake source - cannot actually read") + } +} + +trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = FakeReader() +} + +trait FakeContinuousReadSupport extends ContinuousReadSupport { + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = FakeReader() +} + +trait FakeStreamWriteSupport extends StreamWriteSupport { + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + throw new IllegalStateException("fake sink - cannot actually write") + } +} + +class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupport { + override def shortName(): String = "fake-read-microbatch-only" +} + +class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupport { + override def shortName(): String = "fake-read-continuous-only" +} + +class FakeReadBothModes extends DataSourceRegister + with FakeMicroBatchReadSupport with FakeContinuousReadSupport { + override def shortName(): String = "fake-read-microbatch-continuous" +} + +class FakeReadNeitherMode extends DataSourceRegister { + override def shortName(): String = "fake-read-neither-mode" +} + +class FakeWrite extends DataSourceRegister with FakeStreamWriteSupport { + override def shortName(): String = "fake-write-microbatch-continuous" +} + +class FakeNoWrite extends DataSourceRegister { + override def shortName(): String = "fake-write-neither-mode" +} + + +case class FakeWriteV1FallbackException() extends Exception + +class FakeSink extends Sink { + 
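+  // Minimal V1 sink that discards every batch; the fallback test only checks which sink
+  // class the query ends up with, not what it writes.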
override def addBatch(batchId: Long, data: DataFrame): Unit = {} +} + +class FakeWriteV1Fallback extends DataSourceRegister + with FakeStreamWriteSupport with StreamSinkProvider { + + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { + new FakeSink() + } + + override def shortName(): String = "fake-write-v1-fallback" +} + + +class StreamingDataSourceV2Suite extends StreamTest { + + override def beforeAll(): Unit = { + super.beforeAll() + val fakeCheckpoint = Utils.createTempDir() + spark.conf.set("spark.sql.streaming.checkpointLocation", fakeCheckpoint.getCanonicalPath) + } + + val readFormats = Seq( + "fake-read-microbatch-only", + "fake-read-continuous-only", + "fake-read-microbatch-continuous", + "fake-read-neither-mode") + val writeFormats = Seq( + "fake-write-microbatch-continuous", + "fake-write-neither-mode") + val triggers = Seq( + Trigger.Once(), + Trigger.ProcessingTime(1000), + Trigger.Continuous(1000)) + + private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger) = { + val query = spark.readStream + .format(readFormat) + .load() + .writeStream + .format(writeFormat) + .trigger(trigger) + .start() + query.stop() + query + } + + private def testNegativeCase( + readFormat: String, + writeFormat: String, + trigger: Trigger, + errorMsg: String) = { + val ex = intercept[UnsupportedOperationException] { + testPositiveCase(readFormat, writeFormat, trigger) + } + assert(ex.getMessage.contains(errorMsg)) + } + + private def testPostCreationNegativeCase( + readFormat: String, + writeFormat: String, + trigger: Trigger, + errorMsg: String) = { + val query = spark.readStream + .format(readFormat) + .load() + .writeStream + .format(writeFormat) + .trigger(trigger) + .start() + + eventually(timeout(streamingTimeout)) { + assert(query.exception.isDefined) + assert(query.exception.get.cause != null) + assert(query.exception.get.cause.getMessage.contains(errorMsg)) + } + } + + test("disabled v2 write") { + // Ensure the V2 path works normally and generates a V2 sink.. + val v2Query = testPositiveCase( + "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) + assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink + .isInstanceOf[FakeWriteV1Fallback]) + + // Ensure we create a V1 sink with the config. Note the config is a comma separated + // list, including other fake entries. + val fullSinkName = "org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback" + withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") { + val v1Query = testPositiveCase( + "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) + assert(v1Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink + .isInstanceOf[FakeSink]) + } + } + + // Get a list of (read, write, trigger) tuples for test cases. + val cases = readFormats.flatMap { read => + writeFormats.flatMap { write => + triggers.map(t => (write, t)) + }.map { + case (write, t) => (read, write, t) + } + } + + for ((read, write, trigger) <- cases) { + testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { + val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf).newInstance() + val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() + (readSource, writeSource, trigger) match { + // Valid microbatch queries. 
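+        // (the reader supports microbatch, the writer supports streaming writes, and the
+        // trigger is not continuous, so the query should start and stop cleanly)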
+ case (_: MicroBatchReadSupport, _: StreamWriteSupport, t) + if !t.isInstanceOf[ContinuousTrigger] => + testPositiveCase(read, write, trigger) + + // Valid continuous queries. + case (_: ContinuousReadSupport, _: StreamWriteSupport, _: ContinuousTrigger) => + testPositiveCase(read, write, trigger) + + // Invalid - can't read at all + case (r, _, _) + if !r.isInstanceOf[MicroBatchReadSupport] + && !r.isInstanceOf[ContinuousReadSupport] => + testNegativeCase(read, write, trigger, + s"Data source $read does not support streamed reading") + + // Invalid - can't write + case (_, w, _) if !w.isInstanceOf[StreamWriteSupport] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support streamed writing") + + // Invalid - trigger is continuous but reader is not + case (r, _: StreamWriteSupport, _: ContinuousTrigger) + if !r.isInstanceOf[ContinuousReadSupport] => + testNegativeCase(read, write, trigger, + s"Data source $read does not support continuous processing") + + // Invalid - trigger is microbatch but reader is not + case (r, _, t) + if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => + testPostCreationNegativeCase(read, write, trigger, + s"Data source $read does not support microbatch processing") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index aa163d2211c3..8212fb912ec5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -422,21 +422,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } - test("ConsoleSink can be correctly loaded") { - LastOptions.clear() - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - - val sq = df.writeStream - .format("console") - .option("checkpointLocation", newMetadataDir) - .trigger(ProcessingTime(2.seconds)) - .start() - - sq.awaitTermination(2000L) - } - test("prevent all column partitioning") { withTempDir { dir => val path = dir.getCanonicalPath @@ -450,16 +435,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } - test("ConsoleSink should not require checkpointLocation") { - LastOptions.clear() - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - - val sq = df.writeStream.format("console").start() - sq.stop() - } - private def testMemorySinkCheckpointRecovery(chkLoc: String, provideInWriter: Boolean): Unit = { import testImplicits._ val ms = new MemoryStream[Int](0, sqlContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 8c9bb7d56a35..14b1feb2adc2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -301,7 +301,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be intercept[AnalysisException] { spark.range(10).write.format("csv").mode("overwrite").partitionBy("id").save(path) } - spark.emptyDataFrame.write.format("parquet").mode("overwrite").save(path) } } @@ -563,7 +562,8 @@ class DataFrameReaderWriterSuite extends QueryTest with 
SharedSQLContext with Be "and a same-name temp view exist") { withTable("same_name") { withTempView("same_name") { - sql("CREATE TABLE same_name(id LONG) USING parquet") + val format = spark.sessionState.conf.defaultDataSourceName + sql(s"CREATE TABLE same_name(id LONG) USING $format") spark.range(10).createTempView("same_name") spark.range(20).write.mode(SaveMode.Append).saveAsTable("same_name") checkAnswer(spark.table("same_name"), spark.range(10).toDF()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 904f9f2ad0b2..bc4a120f7042 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -254,13 +254,26 @@ private[sql] trait SQLTestUtilsBase } /** - * Drops temporary table `tableName` after calling `f`. + * Drops temporary view `viewNames` after calling `f`. */ - protected def withTempView(tableNames: String*)(f: => Unit): Unit = { + protected def withTempView(viewNames: String*)(f: => Unit): Unit = { try f finally { // If the test failed part way, we don't want to mask the failure by failing to remove - // temp tables that never got created. - try tableNames.foreach(spark.catalog.dropTempView) catch { + // temp views that never got created. + try viewNames.foreach(spark.catalog.dropTempView) catch { + case _: NoSuchTableException => + } + } + } + + /** + * Drops global temporary view `viewNames` after calling `f`. + */ + protected def withGlobalTempView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // global temp views that never got created. + try viewNames.foreach(spark.catalog.dropGlobalTempView) catch { case _: NoSuchTableException => } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 4d578e21f549..e6c7648c986a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,4 +17,25 @@ package org.apache.spark.sql.test -trait SharedSQLContext extends SQLTestUtils with SharedSparkSession +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession { + + /** + * Suites extending [[SharedSQLContext]] are sharing resources (eg. SparkSession) in their tests. + * That trait initializes the spark session in its [[beforeAll()]] implementation before the + * automatic thread snapshot is performed, so the audit code could fail to report threads leaked + * by that shared session. + * + * The behavior is overridden here to take the snapshot before the spark session is initialized. 
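+   * Concretely, doThreadPreAudit() runs in beforeAll() before super.beforeAll() creates the
+   * session, and doThreadPostAudit() runs in afterAll() after the session has been stopped.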
+ */ + override protected val enableAutoThreadAudit = false + + protected override def beforeAll(): Unit = { + doThreadPreAudit() + super.beforeAll() + } + + protected override def afterAll(): Unit = { + super.afterAll() + doThreadPostAudit() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 0b4629a51b42..8968dbf36d50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -60,6 +60,7 @@ trait SharedSparkSession protected implicit def sqlContext: SQLContext = _spark.sqlContext protected def createSparkSession: TestSparkSession = { + SparkSession.cleanupAnyExistingSession() new TestSparkSession(sparkConf) } @@ -92,11 +93,22 @@ trait SharedSparkSession * Stop the underlying [[org.apache.spark.SparkContext]], if any. */ protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null + try { + super.afterAll() + } finally { + try { + if (_spark != null) { + try { + _spark.sessionState.catalog.reset() + } finally { + _spark.stop() + _spark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } } } @@ -111,7 +123,7 @@ trait SharedSparkSession spark.sharedState.cacheManager.clearCache() // files can be closed from other threads, so wait a bit // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { + eventually(timeout(10.seconds), interval(2.seconds)) { DebugFilesystem.assertNoOpenStreams() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 4286e8a6ca2c..17603deacdcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -34,6 +34,9 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) this(new SparkConf) } + SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) + @transient override lazy val sessionState: SessionState = { new TestSQLSessionStateBuilder(this, None).build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala similarity index 66% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala rename to sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala index 17614e040e58..4019c6888da9 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala @@ -14,17 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +package org.apache.spark.sql.test -/** - * Represents a step in configuring the Spark driver pod. 
- */ -private[spark] trait DriverConfigurationStep { +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession - /** - * Apply some transformation to the previous state of the driver to add a new feature to it. - */ - def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec +class TestSparkSessionSuite extends SparkFunSuite { + test("default session is set in constructor") { + val session = new TestSparkSession() + assert(SparkSession.getDefaultSession.contains(session)) + session.stop() + } } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 3135a8a275da..9f247f9224c7 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index c5ade6528304..10000f12ab32 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -18,6 +18,8 @@ package org.apache.hive.service.auth; import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.util.ArrayList; @@ -26,6 +28,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import javax.net.ssl.SSLServerSocket; import javax.security.auth.login.LoginException; @@ -92,7 +95,30 @@ public String getAuthName() { public static final String HS2_PROXY_USER = "hive.server2.proxy.user"; public static final String HS2_CLIENT_TOKEN = "hiveserver2ClientToken"; - public HiveAuthFactory(HiveConf conf) throws TTransportException { + private static Field keytabFile = null; + private static Method getKeytab = null; + static { + Class clz = UserGroupInformation.class; + try { + keytabFile = clz.getDeclaredField("keytabFile"); + keytabFile.setAccessible(true); + } catch (NoSuchFieldException nfe) { + LOG.debug("Cannot find private field \"keytabFile\" in class: " + + UserGroupInformation.class.getCanonicalName(), nfe); + keytabFile = null; + } + + try { + getKeytab = clz.getDeclaredMethod("getKeytab"); + getKeytab.setAccessible(true); + } catch(NoSuchMethodException nme) { + LOG.debug("Cannot find private method \"getKeytab\" in class:" + + UserGroupInformation.class.getCanonicalName(), nme); + getKeytab = null; + } + } + + public HiveAuthFactory(HiveConf conf) throws TTransportException, IOException { this.conf = conf; transportMode = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE); authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION); @@ -107,9 +133,16 @@ public HiveAuthFactory(HiveConf conf) throws TTransportException { authTypeStr = AuthTypes.NONE.getAuthName(); } if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) { - saslServer = ShimLoader.getHadoopThriftAuthBridge() - .createServer(conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB), - conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)); + String principal = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL); + String keytab = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB); + if (needUgiLogin(UserGroupInformation.getCurrentUser(), + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), 
keytab)) { + saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(principal, keytab); + } else { + // Using the default constructor to avoid unnecessary UGI login. + saslServer = new HadoopThriftAuthBridge.Server(); + } + // start delegation token manager try { // rawStore is only necessary for DBTokenStore @@ -362,4 +395,25 @@ public static void verifyProxyAccess(String realUser, String proxyUser, String i } } + public static boolean needUgiLogin(UserGroupInformation ugi, String principal, String keytab) { + return null == ugi || !ugi.hasKerberosCredentials() || !ugi.getUserName().equals(principal) || + !Objects.equals(keytab, getKeytabFromUgi()); + } + + private static String getKeytabFromUgi() { + synchronized (UserGroupInformation.class) { + try { + if (keytabFile != null) { + return (String) keytabFile.get(null); + } else if (getKeytab != null) { + return (String) getKeytab.invoke(UserGroupInformation.getCurrentUser()); + } else { + return null; + } + } catch (Exception e) { + LOG.debug("Fail to get keytabFile path via reflection", e); + return null; + } + } + } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 47bfaa86021d..f59cdcd3188e 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -44,7 +44,7 @@ import org.apache.hadoop.hive.ql.history.HiveHistory; import org.apache.hadoop.hive.ql.metadata.Hive; import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.processors.SetProcessor; +import org.apache.hadoop.hive.ql.parse.VariableSubstitution; import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.hadoop.hive.shims.ShimLoader; import org.apache.hive.common.util.HiveVersionInfo; @@ -71,6 +71,12 @@ import org.apache.hive.service.cli.thrift.TProtocolVersion; import org.apache.hive.service.server.ThreadWithGarbageCleanup; +import static org.apache.hadoop.hive.conf.SystemVariables.ENV_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.HIVECONF_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.HIVEVAR_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.METACONF_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.SYSTEM_PREFIX; + /** * HiveSession * @@ -209,7 +215,7 @@ private void configureSession(Map sessionConfMap) throws HiveSQL String key = entry.getKey(); if (key.startsWith("set:")) { try { - SetProcessor.setVariable(key.substring(4), entry.getValue()); + setVariable(key.substring(4), entry.getValue()); } catch (Exception e) { throw new HiveSQLException(e); } @@ -221,8 +227,84 @@ private void configureSession(Map sessionConfMap) throws HiveSQL } } + // Copy from org.apache.hadoop.hive.ql.processors.SetProcessor, only change: + // setConf(varname, propName, varvalue, true) when varname.startsWith(HIVECONF_PREFIX) + public static int setVariable(String varname, String varvalue) throws Exception { + SessionState ss = SessionState.get(); + if (varvalue.contains("\n")){ + ss.err.println("Warning: Value had a \\n character in it."); + } + varname = varname.trim(); + if (varname.startsWith(ENV_PREFIX)){ + ss.err.println("env:* variables can not be set."); + return 1; + } else if (varname.startsWith(SYSTEM_PREFIX)){ + String propName = 
varname.substring(SYSTEM_PREFIX.length()); + System.getProperties().setProperty(propName, + new VariableSubstitution().substitute(ss.getConf(),varvalue)); + } else if (varname.startsWith(HIVECONF_PREFIX)){ + String propName = varname.substring(HIVECONF_PREFIX.length()); + setConf(varname, propName, varvalue, true); + } else if (varname.startsWith(HIVEVAR_PREFIX)) { + String propName = varname.substring(HIVEVAR_PREFIX.length()); + ss.getHiveVariables().put(propName, + new VariableSubstitution().substitute(ss.getConf(),varvalue)); + } else if (varname.startsWith(METACONF_PREFIX)) { + String propName = varname.substring(METACONF_PREFIX.length()); + Hive hive = Hive.get(ss.getConf()); + hive.setMetaConf(propName, new VariableSubstitution().substitute(ss.getConf(), varvalue)); + } else { + setConf(varname, varname, varvalue, true); + } + return 0; + } + + // returns non-null string for validation fail + private static void setConf(String varname, String key, String varvalue, boolean register) + throws IllegalArgumentException { + HiveConf conf = SessionState.get().getConf(); + String value = new VariableSubstitution().substitute(conf, varvalue); + if (conf.getBoolVar(HiveConf.ConfVars.HIVECONFVALIDATION)) { + HiveConf.ConfVars confVars = HiveConf.getConfVars(key); + if (confVars != null) { + if (!confVars.isType(value)) { + StringBuilder message = new StringBuilder(); + message.append("'SET ").append(varname).append('=').append(varvalue); + message.append("' FAILED because ").append(key).append(" expects "); + message.append(confVars.typeString()).append(" type value."); + throw new IllegalArgumentException(message.toString()); + } + String fail = confVars.validate(value); + if (fail != null) { + StringBuilder message = new StringBuilder(); + message.append("'SET ").append(varname).append('=').append(varvalue); + message.append("' FAILED in validation : ").append(fail).append('.'); + throw new IllegalArgumentException(message.toString()); + } + } else if (key.startsWith("hive.")) { + throw new IllegalArgumentException("hive configuration " + key + " does not exists."); + } + } + conf.verifyAndSet(key, value); + if (register) { + SessionState.get().getOverriddenConfigurations().put(key, value); + } + } + @Override public void setOperationLogSessionDir(File operationLogRootDir) { + if (!operationLogRootDir.exists()) { + LOG.warn("The operation log root directory is removed, recreating: " + + operationLogRootDir.getAbsolutePath()); + if (!operationLogRootDir.mkdirs()) { + LOG.warn("Unable to create operation log root directory: " + + operationLogRootDir.getAbsolutePath()); + } + } + if (!operationLogRootDir.canWrite()) { + LOG.warn("The operation log root directory is not writable: " + + operationLogRootDir.getAbsolutePath()); + } sessionLogDir = new File(operationLogRootDir, sessionHandle.getHandleIdentifier().toString()); isOperationLogEnabled = true; if (!sessionLogDir.exists()) { @@ -559,6 +641,8 @@ public void close() throws HiveSQLException { opHandleSet.clear(); // Cleanup session log directory. cleanupSessionLogDir(); + // Cleanup pipeout file. 
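+    // Hive creates <sessionID>*.pipeout files in the local scratch dir; delete them here so a
+    // long-running server does not accumulate them (see cleanupPipeoutFile below).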
+ cleanupPipeoutFile(); HiveHistory hiveHist = sessionState.getHiveHistory(); if (null != hiveHist) { hiveHist.closeStream(); @@ -583,6 +667,22 @@ public void close() throws HiveSQLException { } } + private void cleanupPipeoutFile() { + String lScratchDir = hiveConf.getVar(ConfVars.LOCALSCRATCHDIR); + String sessionID = hiveConf.getVar(ConfVars.HIVESESSIONID); + + File[] fileAry = new File(lScratchDir).listFiles( + (dir, name) -> name.startsWith(sessionID) && name.endsWith(".pipeout")); + + for (File file : fileAry) { + try { + FileUtils.forceDelete(file); + } catch (Exception e) { + LOG.error("Failed to cleanup pipeout file: " + file, e); + } + } + } + private void cleanupSessionLogDir() { if (isOperationLogEnabled) { try { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 664bc20601ea..3cfc81b8a957 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -102,7 +102,7 @@ private[hive] class SparkExecuteStatementOperation( to += from.getAs[Timestamp](ordinal) case BinaryType => to += from.getAs[Array[Byte]](ordinal) - case _: ArrayType | _: StructType | _: MapType => + case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] => val hiveString = HiveUtils.toHiveString((from.get(ordinal), dataTypes(ordinal))) to += hiveString } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 832a15d09599..084f8200102b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -34,11 +34,13 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HiveDelegationTokenProvider import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveUtils @@ -121,6 +123,13 @@ private[hive] object SparkSQLCLIDriver extends Logging { } } + val tokenProvider = new HiveDelegationTokenProvider() + if (tokenProvider.delegationTokensRequired(sparkConf, hadoopConf)) { + val credentials = new Credentials() + tokenProvider.obtainDelegationTokens(hadoopConf, sparkConf, credentials) + UserGroupInformation.getCurrentUser.addCredentials(credentials) + } + SessionState.start(sessionState) // Clean up after we exit diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index ad1f5eb9ca3a..1335e16e3588 100644 --- 
a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -27,7 +27,7 @@ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.shims.Utils -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation} import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory @@ -52,8 +52,22 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC if (UserGroupInformation.isSecurityEnabled) { try { - HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = Utils.getUGI() + val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL) + val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB) + if (principal.isEmpty || keyTabFile.isEmpty) { + throw new IOException( + "HiveServer2 Kerberos principal or keytab is not correctly configured") + } + + val originalUgi = UserGroupInformation.getCurrentUser + sparkServiceUGI = if (HiveAuthFactory.needUgiLogin(originalUgi, + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keyTabFile)) { + HiveAuthFactory.loginFromKeytab(hiveConf) + Utils.getUGI() + } else { + originalUgi + } + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 6b19f971b73b..8980bcf88558 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -50,8 +50,7 @@ private[hive] object SparkSQLEnv extends Logging { sqlContext = sparkSession.sqlContext val metadataHive = sparkSession - .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] - .client.newSession() + .sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 48c0ebef3e0c..2958b771f364 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -40,22 +40,8 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: private lazy val sparkSqlOperationManager = new SparkSQLOperationManager() override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - // Create operation log root directory, if operation logging is enabled - if (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) { - invoke(classOf[SessionManager], this, "initOperationLogRootDir") - } - 
- val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) + super.init(hiveConf) } override def openSession( diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index a0e5012633f5..bf7c01f60fb5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} +import org.apache.spark.sql.internal.SQLConf /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -50,6 +51,9 @@ private[thriftserver] class SparkSQLOperationManager() require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + s" initialized or had already closed.") val conf = sqlContext.sessionState.conf + val hiveSessionState = parentSession.getSessionState + setConfMap(conf, hiveSessionState.getOverriddenConfigurations) + setConfMap(conf, hiveSessionState.getHiveVariables) val runInBackground = async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(sqlContext, sessionToActivePool) @@ -58,4 +62,12 @@ private[thriftserver] class SparkSQLOperationManager() s"runInBackground=$runInBackground") operation } + + def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = { + val iterator = confMap.entrySet().iterator() + while (iterator.hasNext) { + val kv = iterator.next() + conf.setConfString(kv.getKey, kv.getValue) + } + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala index 3f135cc86498..277df548aefd 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala @@ -38,6 +38,20 @@ class HiveMetastoreLazyInitializationSuite extends SparkFunSuite { // We should be able to run Spark jobs without Hive client. assert(spark.sparkContext.range(0, 1).count() === 1) + // We should be able to use Spark SQL if no table references. 
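+    // (none of the statements below reference a table, so they must not force the Hive
+    // client to be initialized)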
+ assert(spark.sql("select 1 + 1").count() === 1) + assert(spark.range(0, 1).count() === 1) + + // We should be able to use fs + val path = Utils.createTempDir() + path.delete() + try { + spark.range(0, 1).write.parquet(path.getAbsolutePath) + assert(spark.read.parquet(path.getAbsolutePath).count() === 1) + } finally { + Utils.deleteRecursively(path) + } + // Make sure that we are not using the local derby metastore. val exceptionString = Utils.exceptionString(intercept[AnalysisException] { spark.sql("show tables") diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 7289da71a336..192f33a45e27 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.hive.thriftserver -import java.io.File +import java.io.{File, FilenameFilter} import java.net.URL import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} +import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -135,6 +136,22 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } + test("Support beeline --hiveconf and --hivevar") { + withJdbcStatement() { statement => + executeTest(hiveConfList) + executeTest(hiveVarList) + def executeTest(hiveList: String): Unit = { + hiveList.split(";").foreach{ m => + val kv = m.split("=") + // select "${a}"; ---> avalue + val resultSet = statement.executeQuery("select \"${" + kv(0) + "}\"") + resultSet.next() + assert(resultSet.getString(1) === kv(1)) + } + } + } + } + test("JDBC query execution") { withJdbcStatement("test") { statement => val queries = Seq( @@ -597,6 +614,28 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { bufferSrc.close() } } + + test("SPARK-23547 Cleanup the .pipeout file when the Hive Session closed") { + def pipeoutFileList(sessionID: UUID): Array[File] = { + lScratchDir.listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = { + name.startsWith(sessionID.toString) && name.endsWith(".pipeout") + } + }) + } + + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + val sessionID = sessionHandle.getSessionId + + assert(pipeoutFileList(sessionID).length == 1) + + client.closeSession(sessionHandle) + + assert(pipeoutFileList(sessionID).length == 0) + } + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -740,10 +779,11 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { s"""jdbc:hive2://localhost:$serverPort/ |default? 
|hive.server2.transport.mode=http; - |hive.server2.thrift.http.path=cliservice + |hive.server2.thrift.http.path=cliservice; + |${hiveConfList}#${hiveVarList} """.stripMargin.split("\n").mkString.trim } else { - s"jdbc:hive2://localhost:$serverPort/" + s"jdbc:hive2://localhost:$serverPort/?${hiveConfList}#${hiveVarList}" } def withMultipleConnectionJdbcStatement(tableNames: String*)(fs: (Statement => Unit)*) { @@ -779,15 +819,18 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl private var listeningPort: Int = _ protected def serverPort: Int = listeningPort + protected val hiveConfList = "a=avalue;b=bvalue" + protected val hiveVarList = "c=cvalue;d=dvalue" protected def user = System.getProperty("user.name") protected var warehousePath: File = _ protected var metastorePath: File = _ protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - private val pidDir: File = Utils.createTempDir("thriftserver-pid") + private val pidDir: File = Utils.createTempDir(namePrefix = "thriftserver-pid") protected var logPath: File = _ protected var operationLogPath: File = _ + protected var lScratchDir: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] @@ -825,6 +868,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath + | --hiveconf ${ConfVars.LOCALSCRATCHDIR}=$lScratchDir | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug @@ -854,6 +898,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl metastorePath.delete() operationLogPath = Utils.createTempDir() operationLogPath.delete() + lScratchDir = Utils.createTempDir() + lScratchDir.delete() logPath = null logTailingProcess = null @@ -937,6 +983,9 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl operationLogPath.delete() operationLogPath = null + lScratchDir.delete() + lScratchDir = null + Option(logPath).foreach(_.delete()) logPath = null diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 45791c69b4cb..cebaad5b4ad9 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -62,7 +62,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } override def afterAll() { diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 66fad85ea026..c55ba32fa458 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala 
index 632e3e0c4c3f..011a3ba553cb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -109,8 +109,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** - * Get the raw table metadata from hive metastore directly. The raw table metadata may contains - * special data source properties and should not be exposed outside of `HiveExternalCatalog`. We + * Get the raw table metadata from hive metastore directly. The raw table metadata may contain + * special data source properties that should not be exposed outside of `HiveExternalCatalog`. We * should interpret these special data source properties and restore the original table metadata * before returning it. */ @@ -158,13 +158,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -177,7 +177,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * * Note: As of now, this only supports altering database properties! */ - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { val existingDb = getDatabase(dbDefinition.name) if (existingDb.properties == dbDefinition.properties) { logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + @@ -211,7 +211,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -480,7 +480,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -489,7 +489,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = withClient { @@ -540,7 +540,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. 
*/ - override def doAlterTable(tableDefinition: CatalogTable): Unit = withClient { + override def alterTable(tableDefinition: CatalogTable): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -624,7 +624,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * data schema should not have conflict column names with the existing partition columns, and * should still contain all the existing data columns. */ - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = withClient { @@ -656,21 +656,17 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = withClient { requireTableExists(db, table) val rawTable = getRawTable(db, table) - // For datasource tables and hive serde tables created by spark 2.1 or higher, - // the data schema is stored in the table properties. - val schema = restoreTableMetadata(rawTable).schema - // convert table statistics to properties so that we can persist them through hive client val statsProperties = if (stats.isDefined) { - statsToProperties(stats.get, schema) + statsToProperties(stats.get) } else { new mutable.HashMap[String, String]() } @@ -1028,9 +1024,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat currentFullPath } - private def statsToProperties( - stats: CatalogStatistics, - schema: StructType): Map[String, String] = { + private def statsToProperties(stats: CatalogStatistics): Map[String, String] = { val statsProperties = new mutable.HashMap[String, String]() statsProperties += STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString() @@ -1038,11 +1032,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } - val colNameTypeMap: Map[String, DataType] = - schema.fields.map(f => (f.name, f.dataType)).toMap stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) + colStat.toMap(colName).foreach { case (k, v) => + // Fully qualified name used in table properties for a particular column stat. + // For example, for column "mycol", and "min" stat, this should return + // "spark.sql.statistics.colStats.mycol.min". + statsProperties += (STATISTICS_COL_STATS_PREFIX + k -> v) } } @@ -1058,23 +1053,20 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (statsProps.isEmpty) { None } else { + val colStats = new mutable.HashMap[String, CatalogColumnStat] + val colStatsProps = properties.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)).map { + case (k, v) => k.drop(STATISTICS_COL_STATS_PREFIX.length) -> v + } - val colStats = new mutable.HashMap[String, ColumnStat] - - // For each column, recover its column stats. Note that this is currently a O(n^2) operation, - // but given the number of columns it usually not enormous, this is probably OK as a start. 
- // If we want to map this a linear operation, we'd need a stronger contract between the - // naming convention used for serialization. - schema.foreach { field => - if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { - // If "version" field is defined, then the column stat is defined. - val keyPrefix = columnStatKeyPropName(field.name, "") - val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => - (k.drop(keyPrefix.length), v) - } - ColumnStat.fromMap(table, field, colStatMap).foreach { cs => - colStats += field.name -> cs - } + // Find all the column names by matching the KEY_VERSION properties for them. + colStatsProps.keys.filter { + k => k.endsWith(CatalogColumnStat.KEY_VERSION) + }.map { k => + k.dropRight(CatalogColumnStat.KEY_VERSION.length + 1) + }.foreach { fieldName => + // and for each, create a column stat. + CatalogColumnStat.fromMap(table, fieldName, colStatsProps).foreach { cs => + colStats += fieldName -> cs } } @@ -1093,25 +1085,16 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val rawTable = getRawTable(db, table) - // For datasource tables and hive serde tables created by spark 2.1 or higher, - // the data schema is stored in the table properties. - val schema = restoreTableMetadata(rawTable).schema - // convert partition statistics to properties so that we can persist them through hive api val withStatsProps = lowerCasedParts.map { p => if (p.stats.isDefined) { - val statsProperties = statsToProperties(p.stats.get, schema) + val statsProperties = statsToProperties(p.stats.get) p.copy(parameters = p.parameters ++ statsProperties) } else { p } } - // Note: Before altering table partitions in Hive, you *must* set the current database - // to the one that contains the table of interest. Otherwise you will end up with the - // most helpful error message ever: "Unable to alter partition. alter is not possible." - // See HIVE-2742 for more detail. 
- client.setCurrentDatabase(db) client.alterPartitions(db, table, withStatsProps) } @@ -1225,7 +1208,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction( + override def createFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) @@ -1238,12 +1221,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doDropFunction(db: String, name: String): Unit = withClient { + override def dropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override protected def doAlterFunction( + override def alterFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) @@ -1252,7 +1235,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.alterFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = withClient { @@ -1315,15 +1298,6 @@ object HiveExternalCatalog { val EMPTY_DATA_SCHEMA = new StructType() .add("col", "array", nullable = true, comment = "from deserializer") - /** - * Returns the fully qualified name used in table properties for a particular column stat. - * For example, for column "mycol", and "min" stat, this should return - * "spark.sql.statistics.colStats.mycol.min". - */ - private def columnStatKeyPropName(columnName: String, statKey: String): String = { - STATISTICS_COL_STATS_PREFIX + columnName + "." + statKey - } - // A persisted data source table always store its schema in the catalog. private def getSchemaFromTableProperties(metadata: CatalogTable): StructType = { val errorMessage = "Could not read schema from the hive metastore because it is corrupted." 
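As a hedged illustration of the flattened key scheme used by the rewritten statsToProperties/statsFromProperties above: each statistic is stored under spark.sql.statistics.colStats.<column>.<stat>, and columns are rediscovered by locating their "version" entries. A minimal, self-contained Scala sketch of that round trip, with plain string maps standing in for CatalogColumnStat (constants and behavior are simplified, not the exact Spark code):

object ColumnStatPropsSketch {
  // Assumed constants mirroring STATISTICS_COL_STATS_PREFIX and CatalogColumnStat.KEY_VERSION.
  val prefix = "spark.sql.statistics.colStats."
  val keyVersion = "version"

  // Flatten one column's stats into fully qualified property keys,
  // e.g. "spark.sql.statistics.colStats.mycol.min" -> "0".
  def toProperties(colName: String, stats: Map[String, String]): Map[String, String] =
    stats.map { case (k, v) => s"$prefix$colName.$k" -> v }

  // Recover per-column stats: strip the prefix, find column names via their "version" keys,
  // then collect every remaining entry that belongs to each discovered column.
  def fromProperties(props: Map[String, String]): Map[String, Map[String, String]] = {
    val colProps = props.collect {
      case (k, v) if k.startsWith(prefix) => k.drop(prefix.length) -> v
    }
    colProps.keys
      .filter(_.endsWith("." + keyVersion))
      .map(_.dropRight(keyVersion.length + 1))
      .map { col =>
        col -> colProps.collect {
          case (k, v) if k.startsWith(col + ".") => k.drop(col.length + 1) -> v
        }
      }.toMap
  }
}

For example, fromProperties(toProperties("mycol", Map("version" -> "1", "min" -> "0"))) yields Map("mycol" -> Map("version" -> "1", "min" -> "0")), which mirrors the per-column data the code above reconstructs through CatalogColumnStat.fromMap.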
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 1f11adbd4f62..94ddeae1bf54 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, Gener import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, ExternalCatalog, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper @@ -39,8 +39,8 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType} private[sql] class HiveSessionCatalog( - externalCatalog: HiveExternalCatalog, - globalTempViewManager: GlobalTempViewManager, + externalCatalogBuilder: () => ExternalCatalog, + globalTempViewManagerBuilder: () => GlobalTempViewManager, val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, conf: SQLConf, @@ -48,8 +48,8 @@ private[sql] class HiveSessionCatalog( parser: ParserInterface, functionResourceLoader: FunctionResourceLoader) extends SessionCatalog( - externalCatalog, - globalTempViewManager, + externalCatalogBuilder, + globalTempViewManagerBuilder, functionRegistry, conf, hadoopConf, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 92cb4ef11c9e..2882672f327c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner @@ -35,15 +36,14 @@ import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLo class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) extends BaseSessionStateBuilder(session, parentState) { - private def externalCatalog: HiveExternalCatalog = - session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private def externalCatalog: ExternalCatalogWithListener = session.sharedState.externalCatalog /** * Create a Hive aware resource loader. 
*/ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - val client: HiveClient = externalCatalog.client.newSession() - new HiveSessionResourceLoader(session, client) + new HiveSessionResourceLoader( + session, () => externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client) } /** @@ -51,8 +51,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session */ override protected lazy val catalog: HiveSessionCatalog = { val catalog = new HiveSessionCatalog( - externalCatalog, - session.sharedState.globalTempViewManager, + () => externalCatalog, + () => session.sharedState.globalTempViewManager, new HiveMetastoreCatalog(session), functionRegistry, conf, @@ -96,22 +96,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override val sparkSession: SparkSession = session override def extraPlanningStrategies: Seq[Strategy] = - super.extraPlanningStrategies ++ customPlanningStrategies - - override def strategies: Seq[Strategy] = { - experimentalMethods.extraStrategies ++ - extraPlanningStrategies ++ Seq( - FileSourceStrategy, - DataSourceStrategy(conf), - SpecialLimits, - InMemoryScans, - HiveTableScans, - Scripts, - Aggregation, - JoinSelection, - BasicOperators - ) - } + super.extraPlanningStrategies ++ customPlanningStrategies ++ Seq(HiveTableScans, Scripts) } } @@ -120,8 +105,9 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session class HiveSessionResourceLoader( session: SparkSession, - client: HiveClient) + clientBuilder: () => HiveClient) extends SessionResourceLoader(session) { + private lazy val client = clientBuilder() override def addJar(path: String): Unit = { client.addJar(path) super.addJar(path) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index ab857b905572..a0c197b06dda 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -157,7 +157,7 @@ object HiveAnalysis extends Rule[LogicalPlan] { case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) - CreateHiveTableAsSelectCommand(tableDesc, query, mode) + CreateHiveTableAsSelectCommand(tableDesc, query, query.output, mode) case InsertIntoDir(isLocal, storage, provider, child, overwrite) if DDLUtils.isHiveTable(provider) => @@ -186,15 +186,28 @@ case class RelationConversions( serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } + // Return true for Apache ORC and Hive ORC-related configuration names. + // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`. + private def isOrcProperty(key: String) = + key.startsWith("orc.") || key.contains(".orc.") + + private def isParquetProperty(key: String) = + key.startsWith("parquet.") || key.contains(".parquet.") + private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + + // Consider table and storage properties. For properties existing in both sides, storage + // properties will supersede table properties. 
if (serde.contains("parquet")) { - val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> + val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ + relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = relation.tableMeta.storage.properties + val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ + relation.tableMeta.storage.properties if (conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { sessionCatalog.metastoreCatalog.convertToLogicalRelation( relation, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c489690af8cd..10c960374537 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, WAREHOUSE_PATH} import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ChildFirstURLClassLoader, Utils} private[spark] object HiveUtils extends Logging { @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.1.1.") + s"0.12.0 through 2.3.2.") .stringConf .createWithDefault(builtinHiveVersion) @@ -109,7 +109,7 @@ private[spark] object HiveUtils extends Logging { .doc("When set to true, the built-in ORC reader and writer are used to process " + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + @@ -304,7 +304,7 @@ private[spark] object HiveUtils extends Logging { throw new IllegalArgumentException( "Builtin jars can only be used when hive execution version == hive metastore version. " + s"Execution: $builtinHiveVersion != Metastore: $hiveMetastoreVersion. " + - "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + + s"Specify a valid path to the correct hive jars using ${HIVE_METASTORE_JARS.key} " + s"or change ${HIVE_METASTORE_VERSION.key} to $builtinHiveVersion.") } @@ -312,6 +312,8 @@ private[spark] object HiveUtils extends Logging { // starting from the given classLoader. def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { case null => Array.empty[URL] + case childFirst: ChildFirstURLClassLoader => + childFirst.getURLs() ++ allJars(Utils.getSparkClassLoader) case urlClassLoader: URLClassLoader => urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) case other => allJars(other.getParent) @@ -322,7 +324,7 @@ private[spark] object HiveUtils extends Logging { if (jars.length == 0) { throw new IllegalArgumentException( "Unable to locate hive jars to connect to metastore. 
" + - "Please set spark.sql.hive.metastore.jars.") + s"Please set ${HIVE_METASTORE_JARS.key}.") } logInfo( @@ -458,6 +460,7 @@ private[spark] object HiveUtils extends Logging { case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString HiveDecimal.create(decimal).toString + case (other, _ : UserDefinedType[_]) => other.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index cc8907a0bbc9..b5444a421792 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -381,7 +381,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal - }.unzip + }.toArray.unzip /** * Builds specific unwrappers ahead of time according to object inspector diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 7b7f4e0f1021..da9fe2d3088b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} -import java.util.Locale +import java.lang.{Iterable => JIterable} +import java.util.{Locale, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf @@ -82,8 +82,9 @@ import org.apache.spark.util.{CircularBuffer, Utils} */ private[hive] class HiveClientImpl( override val version: HiveVersion, + warehouseDir: Option[String], sparkConf: SparkConf, - hadoopConf: Configuration, + hadoopConf: JIterable[JMap.Entry[String, String]], extraConfig: Map[String, String], initClassLoader: ClassLoader, val clientLoader: IsolatedClientLoader) @@ -102,6 +103,8 @@ private[hive] class HiveClientImpl( case hive.v1_2 => new Shim_v1_2() case hive.v2_0 => new Shim_v2_0() case hive.v2_1 => new Shim_v2_1() + case hive.v2_2 => new Shim_v2_2() + case hive.v2_3 => new Shim_v2_3() } // Create an internal session state for this HiveClientImpl. @@ -130,7 +133,7 @@ private[hive] class HiveClientImpl( if (ret != null) { // hive.metastore.warehouse.dir is determined in SharedState after the CliSessionState // instance constructed, we need to follow that change here. 
- Option(hadoopConf.get(ConfVars.METASTOREWAREHOUSE.varname)).foreach { dir => + warehouseDir.foreach { dir => ret.getConf.setVar(ConfVars.METASTOREWAREHOUSE, dir) } ret @@ -289,14 +292,20 @@ private[hive] class HiveClientImpl( state.err = stream } - override def setCurrentDatabase(databaseName: String): Unit = withHiveState { - if (databaseExists(databaseName)) { - state.setCurrentDatabase(databaseName) - } else { - throw new NoSuchDatabaseException(databaseName) + private def setCurrentDatabaseRaw(db: String): Unit = { + if (state.getCurrentDatabase != db) { + if (databaseExists(db)) { + state.setCurrentDatabase(db) + } else { + throw new NoSuchDatabaseException(db) + } } } + override def setCurrentDatabase(databaseName: String): Unit = withHiveState { + setCurrentDatabaseRaw(databaseName) + } + override def createDatabase( database: CatalogDatabase, ignoreIfExists: Boolean): Unit = withHiveState { @@ -330,7 +339,7 @@ private[hive] class HiveClientImpl( Option(client.getDatabase(dbName)).map { d => CatalogDatabase( name = d.getName, - description = d.getDescription, + description = Option(d.getDescription).getOrElse(""), locationUri = CatalogUtils.stringToURI(d.getLocationUri), properties = Option(d.getParameters).map(_.asScala.toMap).orNull) }.getOrElse(throw new NoSuchDatabaseException(dbName)) @@ -596,8 +605,18 @@ private[hive] class HiveClientImpl( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withHiveState { - val hiveTable = toHiveTable(getTable(db, table), Some(userName)) - shim.alterPartitions(client, table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) + // Note: Before altering table partitions in Hive, you *must* set the current database + // to the one that contains the table of interest. Otherwise you will end up with the + // most helpful error message ever: "Unable to alter partition. alter is not possible." + // See HIVE-2742 for more detail. 
+ val original = state.getCurrentDatabase + try { + setCurrentDatabaseRaw(db) + val hiveTable = toHiveTable(getTable(db, table), Some(userName)) + shim.alterPartitions(client, table, newParts.map { toHivePartition(_, hiveTable) }.asJava) + } finally { + state.setCurrentDatabase(original) + } } /** @@ -824,19 +843,19 @@ private[hive] class HiveClientImpl( def reset(): Unit = withHiveState { client.getAllTables("default").asScala.foreach { t => - logDebug(s"Deleting table $t") - val table = client.getTable("default", t) - client.getIndexes("default", t, 255).asScala.foreach { index => - shim.dropIndex(client, "default", t, index.getIndexName) - } - if (!table.isIndexTable) { - client.dropTable("default", t) - } + logDebug(s"Deleting table $t") + val table = client.getTable("default", t) + client.getIndexes("default", t, 255).asScala.foreach { index => + shim.dropIndex(client, "default", t, index.getIndexName) } - client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => - logDebug(s"Dropping Database: $db") - client.dropDatabase(db, true, false, true) + if (!table.isIndexTable) { + client.dropTable("default", t) } + } + client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => + logDebug(s"Dropping Database: $db") + client.dropDatabase(db, true, false, true) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 1eac70dbf19c..948ba542b573 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -880,9 +880,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { } -private[client] class Shim_v1_0 extends Shim_v0_14 { - -} +private[client] class Shim_v1_0 extends Shim_v0_14 private[client] class Shim_v1_1 extends Shim_v1_0 { @@ -1146,3 +1144,7 @@ private[client] class Shim_v2_1 extends Shim_v2_0 { alterPartitionsMethod.invoke(hive, tableName, newParts, environmentContextInAlterTable) } } + +private[client] class Shim_v2_2 extends Shim_v2_1 + +private[client] class Shim_v2_3 extends Shim_v2_1 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 7a76fd3fd2eb..c2690ec32b9e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -26,6 +26,7 @@ import scala.util.Try import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkSubmitUtils @@ -48,11 +49,12 @@ private[hive] object IsolatedClientLoader extends Logging { config: Map[String, String] = Map.empty, ivyPath: Option[String] = None, sharedPrefixes: Seq[String] = Seq.empty, - barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { + barrierPrefixes: Seq[String] = Seq.empty, + sharesHadoopClasses: Boolean = true): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(hiveMetastoreVersion) // We will first try to share Hadoop classes. If we cannot resolve the Hadoop artifact // with the given version, we will use Hadoop 2.6 and then will not share Hadoop classes. 
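The sharesHadoopClasses flag added to forVersion above can also be set explicitly by a caller instead of relying on the automatic fallback handled just below. A minimal usage sketch, assuming the leading parameter names of the existing signature (hiveMetastoreVersion, hadoopVersion, sparkConf, hadoopConf) and placed inside org.apache.spark.sql.hive because IsolatedClientLoader is package-private; createClient() still resolves the metastore jars through Ivy at runtime:

package org.apache.spark.sql.hive

import org.apache.hadoop.conf.Configuration

import org.apache.spark.SparkConf
import org.apache.spark.sql.hive.client.IsolatedClientLoader

object SharesHadoopClassesSketch {
  def main(args: Array[String]): Unit = {
    // Ask for a Hive 2.3.2 metastore client whose isolated classloader does not share
    // Hadoop classes with Spark, rather than letting forVersion decide on its own.
    val loader = IsolatedClientLoader.forVersion(
      hiveMetastoreVersion = "2.3.2",
      hadoopVersion = "2.7.3",
      sparkConf = new SparkConf(),
      hadoopConf = new Configuration(),
      sharesHadoopClasses = false)
    println(loader.createClient().version)
  }
}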
- var sharesHadoopClasses = true + var _sharesHadoopClasses = sharesHadoopClasses val files = if (resolvedVersions.contains((resolvedVersion, hadoopVersion))) { resolvedVersions((resolvedVersion, hadoopVersion)) } else { @@ -68,7 +70,7 @@ private[hive] object IsolatedClientLoader extends Logging { "Hadoop classes will not be shared between Spark and Hive metastore client. " + "It is recommended to set jars used by Hive metastore client through " + "spark.sql.hive.metastore.jars in the production environment.") - sharesHadoopClasses = false + _sharesHadoopClasses = false (downloadVersion(resolvedVersion, "2.6.5", ivyPath), "2.6.5") } resolvedVersions.put((resolvedVersion, actualHadoopVersion), downloadedFiles) @@ -81,7 +83,7 @@ private[hive] object IsolatedClientLoader extends Logging { execJars = files, hadoopConf = hadoopConf, config = config, - sharesHadoopClasses = sharesHadoopClasses, + sharesHadoopClasses = _sharesHadoopClasses, sharedPrefixes = sharedPrefixes, barrierPrefixes = barrierPrefixes) } @@ -95,6 +97,8 @@ private[hive] object IsolatedClientLoader extends Logging { case "1.2" | "1.2.0" | "1.2.1" | "1.2.2" => hive.v1_2 case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 + case "2.2" | "2.2.0" => hive.v2_2 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" => hive.v2_3 } private def downloadVersion( @@ -175,8 +179,9 @@ private[hive] class IsolatedClientLoader( val isHadoopClass = name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.") - name.contains("slf4j") || - name.contains("log4j") || + name.startsWith("org.slf4j") || + name.startsWith("org.apache.log4j") || // log4j1.x + name.startsWith("org.apache.logging.log4j") || // log4j2 name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || @@ -249,8 +254,10 @@ private[hive] class IsolatedClientLoader( /** The isolated client interface to Hive. */ private[hive] def createClient(): HiveClient = synchronized { + val warehouseDir = Option(hadoopConf.get(ConfVars.METASTOREWAREHOUSE.varname)) if (!isolationOn) { - return new HiveClientImpl(version, sparkConf, hadoopConf, config, baseClassLoader, this) + return new HiveClientImpl(version, warehouseDir, sparkConf, hadoopConf, config, + baseClassLoader, this) } // Pre-reflective instantiation setup. 
logDebug("Initializing the logger to avoid disaster...") @@ -261,7 +268,7 @@ private[hive] class IsolatedClientLoader( classLoader .loadClass(classOf[HiveClientImpl].getName) .getConstructors.head - .newInstance(version, sparkConf, hadoopConf, config, classLoader, this) + .newInstance(version, warehouseDir, sparkConf, hadoopConf, config, classLoader, this) .asInstanceOf[HiveClient] } catch { case e: InvocationTargetException => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index c14154a3b3c2..681ee9200f02 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -71,7 +71,15 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1) + case object v2_2 extends HiveVersion("2.2.0", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v2_3 extends HiveVersion("2.3.2", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1, v2_2, v2_3) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 65e8b4e3c725..1e801fe1845c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.hive.execution import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.DataWritingCommand /** @@ -36,15 +37,15 @@ import org.apache.spark.sql.execution.command.RunnableCommand case class CreateHiveTableAsSelectCommand( tableDesc: CatalogTable, query: LogicalPlan, + outputColumns: Seq[Attribute], mode: SaveMode) - extends RunnableCommand { + extends DataWritingCommand { private val tableIdentifier = tableDesc.identifier - override def innerChildren: Seq[LogicalPlan] = Seq(query) - - override def run(sparkSession: SparkSession): Seq[Row] = { - if (sparkSession.sessionState.catalog.tableExists(tableIdentifier)) { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + if (catalog.tableExists(tableIdentifier)) { assert(mode != SaveMode.Overwrite, s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite") @@ -56,34 +57,36 @@ case class CreateHiveTableAsSelectCommand( return Seq.empty } - sparkSession.sessionState.executePlan( - InsertIntoTable( - UnresolvedRelation(tableIdentifier), - Map(), - query, 
- overwrite = false, - ifPartitionNotExists = false)).toRdd + InsertIntoHiveTable( + tableDesc, + Map.empty, + query, + overwrite = false, + ifPartitionNotExists = false, + outputColumns = outputColumns).run(sparkSession, child) } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. assert(tableDesc.schema.isEmpty) - sparkSession.sessionState.catalog.createTable( - tableDesc.copy(schema = query.schema), ignoreIfExists = false) + catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false) try { - sparkSession.sessionState.executePlan( - InsertIntoTable( - UnresolvedRelation(tableIdentifier), - Map(), - query, - overwrite = true, - ifPartitionNotExists = false)).toRdd + // Read back the metadata of the table which was created just now. + val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier) + // For CTAS, there is no static partition values to insert. + val partition = createdTableMeta.partitionColumnNames.map(_ -> None).toMap + InsertIntoHiveTable( + createdTableMeta, + partition, + query, + overwrite = true, + ifPartitionNotExists = false, + outputColumns = outputColumns).run(sparkSession, child) } catch { case NonFatal(e) => // drop the created table. - sparkSession.sessionState.catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, - purge = false) + catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, purge = false) throw e } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 5c515515b9b9..802ddafdbee4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -19,7 +19,16 @@ package org.apache.spark.sql.hive.execution import java.util.Locale +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.internal.SQLConf /** * Options for the Hive data source. 
Note that rule `DetermineHiveSerde` will extract Hive @@ -102,4 +111,17 @@ object HiveOptions { "collectionDelim" -> "colelction.delim", "mapkeyDelim" -> "mapkey.delim", "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } + + def getHiveWriteCompression(tableInfo: TableDesc, sqlConf: SQLConf): Option[(String, String)] = { + val tableProps = tableInfo.getProperties.asScala.toMap + tableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("parquetoutputformat") => + val compressionCodec = new ParquetOptions(tableProps, sqlConf).compressionCodecClassName + Option((ParquetOutputFormat.COMPRESSION, compressionCodec)) + case formatName if formatName.endsWith("orcoutputformat") => + val compressionCodec = new OrcOptions(tableProps, sqlConf).compressionCodec + Option((COMPRESS.getAttribute, compressionCodec)) + case _ => None + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3ce5b8469d6f..02a60f16b3b3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -172,7 +172,7 @@ case class InsertIntoHiveTable( val enforceBucketingConfig = "hive.enforce.bucketing" val enforceSortingConfig = "hive.enforce.sorting" - val message = s"Output Hive table ${table.identifier} is bucketed but Spark" + + val message = s"Output Hive table ${table.identifier} is bucketed but Spark " + "currently does NOT populate bucketed output which is compatible with Hive." if (hadoopConf.get(enforceBucketingConfig, "true").toBoolean || diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 9a6607f2f2c6..e0f7375387d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -55,18 +55,28 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, partitionAttributes: Seq[Attribute] = Nil): Set[String] = { - val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean + val isCompressed = + fileSinkConf.getTableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("orcoutputformat") => + // For ORC,"mapreduce.output.fileoutputformat.compress", + // "mapreduce.output.fileoutputformat.compress.codec", and + // "mapreduce.output.fileoutputformat.compress.type" + // have no impact because it uses table properties to store compression information. + false + case _ => hadoopConf.get("hive.exec.compress.output", "false").toBoolean + } + if (isCompressed) { - // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress", - // "mapreduce.output.fileoutputformat.compress.codec", and - // "mapreduce.output.fileoutputformat.compress.type" - // have no impact on ORC because it uses table properties to store compression information. 
hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") fileSinkConf.setCompressed(true) fileSinkConf.setCompressCodec(hadoopConf .get("mapreduce.output.fileoutputformat.compress.codec")) fileSinkConf.setCompressType(hadoopConf .get("mapreduce.output.fileoutputformat.compress.type")) + } else { + // Set compression by priority + HiveOptions.getHiveWriteCompression(fileSinkConf.getTableInfo, sparkSession.sessionState.conf) + .foreach { case (compression, codec) => hadoopConf.set(compression, codec) } } val committer = FileCommitProtocol.instantiate( @@ -104,14 +114,15 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { // staging directory under the table director for Hive prior to 1.1, the staging directory will // be removed by Hive when Hive is trying to empty the table directory. val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) - val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = + Set(v1_1, v1_2, v2_0, v2_1, v2_2, v2_3) // Ensure all the supported versions are considered here. assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == allSupportedHiveVersions) val externalCatalog = sparkSession.sharedState.externalCatalog - val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version + val hiveVersion = externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.version val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 95741c7b3028..237ed9bc0598 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -59,9 +59,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles OrcFileOperator.readSchema( files.map(_.getPath.toString), - Some(sparkSession.sessionState.newHadoopConf()) + Some(sparkSession.sessionState.newHadoopConf()), + ignoreCorruptFiles ) } @@ -129,6 +131,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -138,7 +141,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. 
- val isEmptyFile = OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf)).isEmpty + val isEmptyFile = + OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf), ignoreCorruptFiles).isEmpty if (isEmptyFile) { Iterator.empty } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 5a3fcd7a759c..80e44ca50435 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.hive.orc +import java.io.IOException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -46,7 +49,10 @@ private[hive] object OrcFileOperator extends Logging { * create the result reader from that file. If no such file is found, it returns `None`. * @todo Needs to consider all files when schema evolution is taken into account. */ - def getFileReader(basePath: String, config: Option[Configuration] = None): Option[Reader] = { + def getFileReader(basePath: String, + config: Option[Configuration] = None, + ignoreCorruptFiles: Boolean = false) + : Option[Reader] = { def isWithNonEmptySchema(path: Path, reader: Reader): Boolean = { reader.getObjectInspector match { case oi: StructObjectInspector if oi.getAllStructFieldRefs.size() == 0 => @@ -65,16 +71,28 @@ private[hive] object OrcFileOperator extends Logging { } listOrcFiles(basePath, conf).iterator.map { path => - path -> OrcFile.createReader(fs, path) + val reader = try { + Some(OrcFile.createReader(fs, path)) + } catch { + case e: IOException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $path", e) + None + } else { + throw new SparkException(s"Could not read footer for file: $path", e) + } + } + path -> reader }.collectFirst { - case (path, reader) if isWithNonEmptySchema(path, reader) => reader + case (path, Some(reader)) if isWithNonEmptySchema(path, reader) => reader } } - def readSchema(paths: Seq[String], conf: Option[Configuration]): Option[StructType] = { + def readSchema(paths: Seq[String], conf: Option[Configuration], ignoreCorruptFiles: Boolean) + : Option[StructType] = { // Take the first file where we can open a valid reader if we can find one. Otherwise just // return None to indicate we can't infer the schema. 
- paths.flatMap(getFileReader(_, conf)).headOption.map { reader => + paths.flatMap(getFileReader(_, conf, ignoreCorruptFiles)).headOption.map { reader => val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index b6be00dbb3a7..ee3f99ab7e9b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand @@ -83,11 +84,11 @@ private[hive] class TestHiveSharedState( hiveClient: Option[HiveClient] = None) extends SharedState(sc) { - override lazy val externalCatalog: TestHiveExternalCatalog = { - new TestHiveExternalCatalog( + override lazy val externalCatalog: ExternalCatalogWithListener = { + new ExternalCatalogWithListener(new TestHiveExternalCatalog( sc.conf, sc.hadoopConfiguration, - hiveClient) + hiveClient)) } } @@ -175,12 +176,21 @@ private[hive] class TestHiveSparkSession( loadTestTables) } + SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) + { // set the metastore temporary configuration val metastoreTempConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", // scratch directory used by Hive's metastore client ConfVars.SCRATCHDIR.varname -> TestHiveContext.makeScratchDir().toURI.toString, - ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") ++ + // After session cloning, the JDBC connect string for a JDBC metastore should not be changed. + existingSharedState.map { state => + val connKey = + state.sparkContext.hadoopConfiguration.get(ConfVars.METASTORECONNECTURLKEY.varname) + ConfVars.METASTORECONNECTURLKEY.varname -> connKey + } metastoreTempConf.foreach { case (k, v) => sc.hadoopConfiguration.set(k, v) @@ -199,7 +209,9 @@ private[hive] class TestHiveSparkSession( new TestHiveSessionStateBuilder(this, parentSessionState).build() } - lazy val metadataHive: HiveClient = sharedState.externalCatalog.client.newSession() + lazy val metadataHive: HiveClient = { + sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.newSession() + } override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) @@ -486,8 +498,7 @@ private[hive] class TestHiveSparkSession( protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames /** - * Resets the test instance by deleting any tables that have been created. - * TODO: also clear out UDFs, views, etc. 
+ * Resets the test instance by deleting any table, view, temp view, and UDF that have been created */ def reset() { try { @@ -513,8 +524,9 @@ private[hive] class TestHiveSparkSession( // an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with // ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to // delete it. Later, it will be re-created with the right permission. - val location = new Path(sc.hadoopConfiguration.get(ConfVars.SCRATCHDIR.varname)) - val fs = location.getFileSystem(sc.hadoopConfiguration) + val hadoopConf = sessionState.newHadoopConf() + val location = new Path(hadoopConf.get(ConfVars.SCRATCHDIR.varname)) + val fs = location.getFileSystem(hadoopConf) fs.delete(location, true) // Some tests corrupt this value on purpose, which breaks the RESET call below. @@ -525,8 +537,6 @@ private[hive] class TestHiveSparkSession( // For some reason, RESET does not reset the following variables... // https://issues.apache.org/jira/browse/HIVE-9004 metadataHive.runSqlHive("set hive.table.parameters.default=") - metadataHive.runSqlHive("set datanucleus.cache.collections=true") - metadataHive.runSqlHive("set datanucleus.cache.collections.lazy=true") // Lots of tests fail if we do not change the partition whitelist from the default. metadataHive.runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") @@ -570,7 +580,7 @@ private[hive] class TestHiveQueryExecution( logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. - sparkSession.sessionState.analyzer.execute(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index d9cf1f361c1d..61f9179042fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -108,10 +108,7 @@ class ExpressionSQLBuilderSuite extends QueryTest with TestHiveSingleton { } test("window specification") { - val frame = SpecifiedWindowFrame.defaultWindowFrame( - hasOrderSpecification = true, - acceptWindowFrame = true - ) + val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) checkSQL( WindowSpecDefinition('a.int :: Nil, Nil, frame), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 48ab4eb9a617..569f00c053e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -38,7 +38,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto val plan = table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala new file mode 100644 index 000000000000..4550d350f6db --- /dev/null +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetTest} +import org.apache.spark.sql.hive.orc.OrcFileOperator +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf + +class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with BeforeAndAfterAll { + import spark.implicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + (0 until maxRecordNum).toDF("a").createOrReplaceTempView("table_source") + } + + override def afterAll(): Unit = { + try { + spark.catalog.dropTempView("table_source") + } finally { + super.afterAll() + } + } + + private val maxRecordNum = 50 + + private def getConvertMetastoreConfName(format: String): String = format.toLowerCase match { + case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key + case "orc" => HiveUtils.CONVERT_METASTORE_ORC.key + } + + private def getSparkCompressionConfName(format: String): String = format.toLowerCase match { + case "parquet" => SQLConf.PARQUET_COMPRESSION.key + case "orc" => SQLConf.ORC_COMPRESSION.key + } + + private def getHiveCompressPropName(format: String): String = format.toLowerCase match { + case "parquet" => ParquetOutputFormat.COMPRESSION + case "orc" => COMPRESS.getAttribute + } + + private def normalizeCodecName(format: String, name: String): String = { + format.toLowerCase match { + case "parquet" => ParquetOptions.getParquetCompressionCodecName(name) + case "orc" => OrcOptions.getORCCompressionCodecName(name) + } + } + + private def getTableCompressionCodec(path: String, format: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = format.toLowerCase match { + case "parquet" => for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + case "orc" => new File(path).listFiles().filter { file => + file.isFile && !file.getName.endsWith(".crc") && file.getName != "_SUCCESS" + }.map { orcFile => + OrcFileOperator.getFileReader(orcFile.toPath.toString).get.getCompression.toString + }.toSeq + } + codecs.distinct + } + + private def createTable( + rootDir: File, + tableName: String, + isPartitioned: Boolean, + 
format: String, + compressionCodec: Option[String]): Unit = { + val tblProperties = compressionCodec match { + case Some(prop) => s"TBLPROPERTIES('${getHiveCompressPropName(format)}'='$prop')" + case _ => "" + } + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p string)" else "" + sql( + s""" + |CREATE TABLE $tableName(a int) + |$partitionCreate + |STORED AS $format + |LOCATION '${rootDir.toURI.toString.stripSuffix("/")}/$tableName' + |$tblProperties + """.stripMargin) + } + + private def writeDataToTable( + tableName: String, + partitionValue: Option[String]): Unit = { + val partitionInsert = partitionValue.map(p => s"partition (p='$p')").mkString + sql( + s""" + |INSERT INTO TABLE $tableName + |$partitionInsert + |SELECT * FROM table_source + """.stripMargin) + } + + private def writeDateToTableUsingCTAS( + rootDir: File, + tableName: String, + partitionValue: Option[String], + format: String, + compressionCodec: Option[String]): Unit = { + val partitionCreate = partitionValue.map(p => s"PARTITIONED BY (p)").mkString + val compressionOption = compressionCodec.map { codec => + s",'${getHiveCompressPropName(format)}'='$codec'" + }.mkString + val partitionSelect = partitionValue.map(p => s",'$p' AS p").mkString + sql( + s""" + |CREATE TABLE $tableName + |USING $format + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName' $compressionOption) + |$partitionCreate + |AS SELECT * $partitionSelect FROM table_source + """.stripMargin) + } + + private def getPreparedTablePath( + tmpDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String], + usingCTAS: Boolean): String = { + val partitionValue = if (isPartitioned) Some("test") else None + if (usingCTAS) { + writeDateToTableUsingCTAS(tmpDir, tableName, partitionValue, format, compressionCodec) + } else { + createTable(tmpDir, tableName, isPartitioned, format, compressionCodec) + writeDataToTable(tableName, partitionValue) + } + getTablePartitionPath(tmpDir, tableName, partitionValue) + } + + private def getTableSize(path: String): Long = { + val dir = new File(path) + val files = dir.listFiles().filter(_.getName.startsWith("part-")) + files.map(_.length()).sum + } + + private def getTablePartitionPath( + dir: File, + tableName: String, + partitionValue: Option[String]) = { + val partitionPath = partitionValue.map(p => s"p=$p").mkString + s"${dir.getPath.stripSuffix("/")}/$tableName/$partitionPath" + } + + private def getUncompressedDataSizeByFormat( + format: String, isPartitioned: Boolean, usingCTAS: Boolean): Long = { + var totalSize = 0L + val tableName = s"tbl_$format" + val codecName = normalizeCodecName(format, "uncompressed") + withSQLConf(getSparkCompressionConfName(format) -> codecName) { + withTempDir { tmpDir => + withTable(tableName) { + val compressionCodec = Option(codecName) + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) + totalSize = getTableSize(path) + } + } + } + assert(totalSize > 0L) + totalSize + } + + private def checkCompressionCodecForTable( + format: String, + isPartitioned: Boolean, + compressionCodec: Option[String], + usingCTAS: Boolean) + (assertion: (String, Long) => Unit): Unit = { + val tableName = + if (usingCTAS) s"tbl_$format$isPartitioned" else s"tbl_$format${isPartitioned}_CAST" + withTempDir { tmpDir => + withTable(tableName) { + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) + val relCompressionCodecs = 
getTableCompressionCodec(path, format) + assert(relCompressionCodecs.length == 1) + val tableSize = getTableSize(path) + assertion(relCompressionCodecs.head, tableSize) + } + } + } + + private def checkTableCompressionCodecForCodecs( + format: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + usingCTAS: Boolean, + compressionCodecs: List[String], + tableCompressionCodecs: List[String]) + (assertionCompressionCodec: (Option[String], String, String, Long) => Unit): Unit = { + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString) { + tableCompressionCodecs.foreach { tableCompression => + compressionCodecs.foreach { sessionCompressionCodec => + withSQLConf(getSparkCompressionConfName(format) -> sessionCompressionCodec) { + // 'tableCompression = null' means no table-level compression + val compression = Option(tableCompression) + checkCompressionCodecForTable(format, isPartitioned, compression, usingCTAS) { + case (realCompressionCodec, tableSize) => + assertionCompressionCodec( + compression, sessionCompressionCodec, realCompressionCodec, tableSize) + } + } + } + } + } + } + + // When the amount of data is small, compressed data size may be larger than uncompressed one, + // so we just check the difference when compressionCodec is not NONE or UNCOMPRESSED. + private def checkTableSize( + format: String, + compressionCodec: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + usingCTAS: Boolean, + tableSize: Long): Boolean = { + val uncompressedSize = getUncompressedDataSizeByFormat(format, isPartitioned, usingCTAS) + compressionCodec match { + case "UNCOMPRESSED" if format == "parquet" => tableSize == uncompressedSize + case "NONE" if format == "orc" => tableSize == uncompressedSize + case _ => tableSize != uncompressedSize + } + } + + def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. + Seq(false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = compressCodecs) { + case (tableCodec, sessionCodec, realCodec, tableSize) => + val expectCodec = tableCodec.get + assert(expectCodec == realCodec) + assert(checkTableSize( + format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) + } + } + } + } + } + + def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. 
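Aside (editor's sketch, not part of the patch): the suite above exercises the precedence between a table-level codec declared via TBLPROPERTIES and the session-level spark.sql.parquet.compression.codec / spark.sql.orc.compression.codec settings; when both are set, checkForTableWithCompressProp expects the table-level value to win. A minimal, self-contained illustration of that setup, assuming a local Hive-enabled session and an illustrative table name, might look like this:

import org.apache.spark.sql.SparkSession

object CompressionPrecedenceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .enableHiveSupport()
      .getOrCreate()

    // Session-level default codec for Parquet writes.
    spark.conf.set("spark.sql.parquet.compression.codec", "snappy")

    // Table-level codec declared via TBLPROPERTIES; the suite above expects this value,
    // not the session default, to show up in the files written for this table.
    spark.sql(
      """CREATE TABLE demo_compress(a INT)
        |STORED AS PARQUET
        |TBLPROPERTIES('parquet.compression'='GZIP')""".stripMargin)

    spark.range(100).selectExpr("CAST(id AS INT) AS a")
      .write.insertInto("demo_compress")

    spark.stop()
  }
}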
+ Seq(false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = List(null)) { + case (tableCodec, sessionCodec, realCodec, tableSize) => + // Always expect session-level take effect + assert(sessionCodec == realCodec) + assert(checkTableSize( + format, sessionCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) + } + } + } + } + } + + test("both table-level and session-level compression are set") { + checkForTableWithCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + test("table-level compression is not set but session-level compressions is set ") { + checkForTableWithoutCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithoutCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + def checkTableWriteWithCompressionCodecs(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + withTempDir { tmpDir => + val tableName = s"tbl_$format$isPartitioned" + createTable(tmpDir, tableName, isPartitioned, format, None) + withTable(tableName) { + compressCodecs.foreach { compressionCodec => + val partitionValue = if (isPartitioned) Some(compressionCodec) else None + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString, + getSparkCompressionConfName(format) -> compressionCodec + ) { writeDataToTable(tableName, partitionValue) } + } + val tablePath = getTablePartitionPath(tmpDir, tableName, None) + val realCompressionCodecs = + if (isPartitioned) compressCodecs.flatMap { codec => + getTableCompressionCodec(s"$tablePath/p=$codec", format) + } else { + getTableCompressionCodec(tablePath, format) + } + + assert(realCompressionCodecs.distinct.sorted == compressCodecs.sorted) + val recordsNum = sql(s"SELECT * from $tableName").count() + assert(recordsNum == maxRecordNum * compressCodecs.length) + } + } + } + } + } + + test("test table containing mixed compression codec") { + checkTableWriteWithCompressionCodecs("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkTableWriteWithCompressionCodecs("orc", List("NONE", "SNAPPY", "ZLIB")) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala index 8a7423663f28..a80db765846e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEach { + override protected val enableAutoThreadAudit = false private var sc: SparkContext = null private var hc: HiveContext = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 2e35fdeba464..0a522b6a11c8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -107,4 +107,10 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { .filter(_.contains("Num 
Buckets")).head assert(bucketString.contains("10")) } + + test("SPARK-23001: NullPointerException when running desc database") { + val catalog = newBasicCatalog() + catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) + assert(catalog.getDatabase("dbWithNullDesc").description == "") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index a3d5b941a676..ea86ab9772bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.sql.hive import java.io.File -import java.nio.file.Files +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} import scala.sys.process._ -import org.apache.spark.TestUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SecurityManager, SparkConf, TestUtils} import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -55,34 +58,67 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def tryDownloadSpark(version: String, path: String): Unit = { // Try mirrors a few times until one succeeds for (i <- 0 until 3) { + // we don't retry on a failure to get mirror url. If we can't get a mirror url, + // the test fails (getStringFromUrl will throw an exception) val preferredMirror = - Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + val filename = s"spark-$version-bin-hadoop2.7.tgz" + val url = s"$preferredMirror/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") - if (Seq("wget", url, "-q", "-P", path).! == 0) { - return + try { + getFileFromUrl(url, path, filename) + val downloaded = new File(sparkTestingDir, filename).getCanonicalPath + val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath + + Seq("mkdir", targetDir).! + val exitCode = Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! + Seq("rm", downloaded).! + + // For a corrupted file, `tar` returns non-zero values. However, we also need to check + // the extracted file because `tar` returns 0 for empty file. + val sparkSubmit = new File(sparkTestingDir, s"spark-$version/bin/spark-submit") + if (exitCode == 0 && sparkSubmit.exists()) { + return + } else { + Seq("rm", "-rf", targetDir).! + } + } catch { + case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } - logWarning(s"Failed to download Spark $version from $url") } fail(s"Unable to download Spark $version") } + private def genDataDir(name: String): String = { + new File(tmpDataDir, name).getCanonicalPath + } - private def downloadSpark(version: String): Unit = { - tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) + private def getFileFromUrl(urlString: String, targetDir: String, filename: String): Unit = { + val conf = new SparkConf + // if the caller passes the name of an existing file, we want doFetchFile to write over it with + // the contents from the specified url. 
+ conf.set("spark.files.overwrite", "true") + val securityManager = new SecurityManager(conf) + val hadoopConf = new Configuration - val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath - val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath + val outDir = new File(targetDir) + if (!outDir.exists()) { + outDir.mkdirs() + } - Seq("mkdir", targetDir).! + // propagate exceptions up to the caller of getFileFromUrl + Utils.doFetchFile(urlString, outDir, filename, conf, securityManager, hadoopConf) + } - Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! + private def getStringFromUrl(urlString: String): String = { + val contentFile = File.createTempFile("string-", ".txt") + contentFile.deleteOnExit() - Seq("rm", downloaded).! - } + // exceptions will propagate to the caller of getStringFromUrl + getFileFromUrl(urlString, contentFile.getParent, contentFile.getName) - private def genDataDir(name: String): String = { - new File(tmpDataDir, name).getCanonicalPath + val contentPath = Paths.get(contentFile.toURI) + new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) } override def beforeAll(): Unit = { @@ -125,7 +161,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") if (!sparkHome.exists()) { - downloadSpark(version) + tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) } val args = Seq( @@ -159,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1", "2.3.0") protected var spark: SparkSession = _ @@ -213,7 +249,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // SPARK-22356: overlapped columns between data and partition schema in data source tables val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" - // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0. 
+ // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0, 2.2.1, 2.3+ if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { spark.sql("msck repair table " + tbl_with_col_overlap) assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala index 285f35b0b0ea..fd5f47e42823 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala @@ -26,7 +26,7 @@ class HiveExternalSessionCatalogSuite extends SessionCatalogSuite with TestHiveS private val externalCatalog = { val catalog = spark.sharedState.externalCatalog - catalog.asInstanceOf[HiveExternalCatalog].client.reset() + catalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.reset() catalog } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 18137e7ea1d6..ba9b944e4a05 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -146,6 +146,12 @@ class DataSourceWithHiveMetastoreCatalogSuite 'id cast StringType as 'd2 ).coalesce(1) + override def beforeAll(): Unit = { + super.beforeAll() + sparkSession.sessionState.catalog.reset() + sparkSession.metadataHive.reset() + } + Seq( "parquet" -> (( "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index f2d27671094d..51a48a20daaa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -50,7 +50,8 @@ class HiveSchemaInferenceSuite FileStatusCache.resetForTesting() } - private val externalCatalog = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private val externalCatalog = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] private val client = externalCatalog.client // Return a copy of the given schema with all field names converted to lower case. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index 958ad3e1c3ce..a3579862c9e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.scalatest.BeforeAndAfterEach +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -25,11 +25,11 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton /** * Run all tests from `SessionStateSuite` with a Hive based `SessionState`. 
*/ -class HiveSessionStateSuite extends SessionStateSuite - with TestHiveSingleton with BeforeAndAfterEach { +class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton { override def beforeAll(): Unit = { // Reuse the singleton session + super.beforeAll() activeSession = spark } @@ -38,4 +38,16 @@ class HiveSessionStateSuite extends SessionStateSuite activeSession = null super.afterAll() } + + test("Clone then newSession") { + val sparkSession = hiveContext.sparkSession + val conf = sparkSession.sparkContext.hadoopConfiguration + val oldValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) + sparkSession.cloneSession() + sparkSession.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.newSession() + val newValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) + assert(oldValue == newValue, + "cloneSession and then newSession should not affect the Derby directory") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 21b3e281490c..aa5b53199261 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -44,11 +44,12 @@ class HiveSparkSubmitSuite with BeforeAndAfterEach with ResetSystemProperties { + override protected val enableAutoThreadAudit = false + // TODO: rewrite these or mark them as slow tests to be run sparingly override def beforeEach() { super.beforeEach() - System.setProperty("spark.testing", "true") } test("temporary Hive UDF: define a UDF and use it") { @@ -329,7 +330,7 @@ class HiveSparkSubmitSuite object SetMetastoreURLTest extends Logging { def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkConf = new SparkConf(loadDefaults = true) val builder = SparkSession.builder() @@ -353,7 +354,7 @@ object SetMetastoreURLTest extends Logging { // HiveExternalCatalog is used when Hive support is enabled. val actualMetastoreURL = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client .getConf("javax.jdo.option.ConnectionURL", "this_is_a_wrong_URL") logInfo(s"javax.jdo.option.ConnectionURL is $actualMetastoreURL") @@ -367,7 +368,7 @@ object SetMetastoreURLTest extends Logging { object SetWarehouseLocationTest extends Logging { def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkConf = new SparkConf(loadDefaults = true).set("spark.ui.enabled", "false") val providedExpectedWarehouseLocation = @@ -446,7 +447,7 @@ object SetWarehouseLocationTest extends Logging { // can load the jar defined with the function. object TemporaryHiveUDFTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -484,7 +485,7 @@ object TemporaryHiveUDFTest extends Logging { // can load the jar defined with the function. 
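Aside (editor's sketch): the recurring change in these suites follows from wrapping the external catalog in ExternalCatalogWithListener; call sites that previously cast sharedState.externalCatalog straight to HiveExternalCatalog now go through unwrapped first. A small helper capturing that access pattern, placed in the org.apache.spark.sql.hive package so the Hive client is visible, could look like:

package org.apache.spark.sql.hive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.hive.client.HiveClient

object UnwrappedCatalogSketch {
  // The listener wrapper only delegates catalog calls, so Hive-specific APIs such as
  // runSqlHive or newSession are reached by unwrapping to the underlying catalog.
  def hiveClient(spark: SparkSession): HiveClient =
    spark.sharedState.externalCatalog
      .unwrapped.asInstanceOf[HiveExternalCatalog]
      .client
}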
object PermanentHiveUDFTest1 extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -522,7 +523,7 @@ object PermanentHiveUDFTest1 extends Logging { // can load the jar defined with the function. object PermanentHiveUDFTest2 extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -557,7 +558,7 @@ object PermanentHiveUDFTest2 extends Logging { // We test if we can load user jars in both driver and executors when HiveContext is used. object SparkSubmitClassLoaderTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val hiveWarehouseLocation = Utils.createTempDir() conf.set("spark.ui.enabled", "false") @@ -627,7 +628,7 @@ object SparkSubmitClassLoaderTest extends Logging { // We test if we can correctly set spark sql configurations when HiveContext is used. object SparkSQLConfTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") // We override the SparkConf to add spark.sql.hive.metastore.version and // spark.sql.hive.metastore.jars to the beginning of the conf entry array. // So, if metadataHive get initialized after we set spark.sql.hive.metastore.version but @@ -668,7 +669,7 @@ object SPARK_9757 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val hiveWarehouseLocation = Utils.createTempDir() val sparkContext = new SparkContext( @@ -717,7 +718,7 @@ object SPARK_11009 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkContext = new SparkContext( new SparkConf() @@ -748,7 +749,7 @@ object SPARK_14244 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkContext = new SparkContext( new SparkConf() @@ -779,7 +780,8 @@ object SPARK_18360 { val defaultDbLocation = spark.catalog.getDatabase("default").locationUri assert(new Path(defaultDbLocation) == new Path(spark.sharedState.warehousePath)) - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val hiveClient = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client try { val tableMeta = CatalogTable( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index fdbfcf1a6844..f2b75e4b23f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.hive +import java.net.URL + import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils +import 
org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SQLTestUtils} +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -42,4 +47,25 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton assert(hiveConf("foo") === "bar") } } + + test("ChildFirstURLClassLoader's parent is null, get spark classloader instead") { + val conf = new SparkConf + val contextClassLoader = Thread.currentThread().getContextClassLoader + val loader = new ChildFirstURLClassLoader(Array(), contextClassLoader) + try { + Thread.currentThread().setContextClassLoader(loader) + HiveUtils.newClientForMetadata( + conf, + SparkHadoopUtil.newConfiguration(conf), + HiveUtils.newTemporaryConfiguration(useInMemoryDerby = true)) + } finally { + Thread.currentThread().setContextClassLoader(contextClassLoader) + } + } + + test("toHiveString correctly handles UDTs") { + val point = new ExamplePoint(50.0, 50.0) + val tpe = new ExamplePointUDT() + assert(HiveUtils.toHiveString((point, tpe)) === "(50.0, 50.0)") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index c8caba83bf36..d93215fefb81 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,14 +23,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.execution.command.CreateTableCommand import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.hive.HiveExternalCatalog._ -import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf._ @@ -593,7 +591,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("Pre insert nullability check (ArrayType)") { - withTable("arrayInParquet") { + withTable("array") { { val df = (Tuple1(Seq(Int.box(1), null: Integer)) :: Nil).toDF("a") val expectedSchema = @@ -606,9 +604,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Overwrite) - .saveAsTable("arrayInParquet") + .saveAsTable("array") } { @@ -623,25 +620,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Append) - .insertInto("arrayInParquet") + .insertInto("array") } (Tuple1(Seq(4, 5)) :: Nil).toDF("a") .write .mode(SaveMode.Append) - .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. + .saveAsTable("array") // This one internally calls df2.insertInto. 
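Aside (editor's sketch): the new ChildFirstURLClassLoader test above uses the usual swap-and-restore discipline for the thread context classloader inside try/finally. Factored out as a generic helper (illustrative only, not an existing Spark utility), the pattern is:

object WithContextClassLoaderSketch {
  /** Runs `body` with `loader` as the context classloader, always restoring the original. */
  def withContextClassLoader[T](loader: ClassLoader)(body: => T): T = {
    val previous = Thread.currentThread().getContextClassLoader
    Thread.currentThread().setContextClassLoader(loader)
    try body finally Thread.currentThread().setContextClassLoader(previous)
  }
}

Restoring in a finally block keeps an exception thrown by body (for example a failed metastore client bootstrap) from leaking the child-first loader into later tests.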
(Tuple1(Seq(Int.box(6), null: Integer)) :: Nil).toDF("a") .write .mode(SaveMode.Append) - .saveAsTable("arrayInParquet") + .saveAsTable("array") - sparkSession.catalog.refreshTable("arrayInParquet") + sparkSession.catalog.refreshTable("array") checkAnswer( - sql("SELECT a FROM arrayInParquet"), + sql("SELECT a FROM array"), Row(ArrayBuffer(1, null)) :: Row(ArrayBuffer(2, 3)) :: Row(ArrayBuffer(4, 5)) :: @@ -650,7 +646,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("Pre insert nullability check (MapType)") { - withTable("mapInParquet") { + withTable("map") { { val df = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("a") val expectedSchema = @@ -663,9 +659,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Overwrite) - .saveAsTable("mapInParquet") + .saveAsTable("map") } { @@ -680,27 +675,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Append) - .insertInto("mapInParquet") + .insertInto("map") } (Tuple1(Map(4 -> 5)) :: Nil).toDF("a") .write - .format("parquet") .mode(SaveMode.Append) - .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. + .saveAsTable("map") // This one internally calls df2.insertInto. (Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") .write - .format("parquet") .mode(SaveMode.Append) - .saveAsTable("mapInParquet") + .saveAsTable("map") - sparkSession.catalog.refreshTable("mapInParquet") + sparkSession.catalog.refreshTable("map") checkAnswer( - sql("SELECT a FROM mapInParquet"), + sql("SELECT a FROM map"), Row(Map(1 -> null)) :: Row(Map(2 -> 3)) :: Row(Map(4 -> 5)) :: @@ -854,52 +846,52 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv (from to to).map(i => i -> s"str$i").toDF("c1", "c2") } - withTable("insertParquet") { - createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") + withTable("t") { + createDF(0, 9).write.saveAsTable("t") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM t p WHERE p.c1 > 5"), (6 to 9).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") + createDF(10, 19).write.saveAsTable("t") } - createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + createDF(10, 19).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM t p WHERE p.c1 > 5"), (6 to 19).map(i => Row(i, s"str$i"))) - createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + createDF(20, 29).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 25"), (6 to 24).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(30, 39).write.saveAsTable("insertParquet") + createDF(30, 39).write.saveAsTable("t") } - createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") + createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 
35"), (6 to 34).map(i => Row(i, s"str$i"))) - createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") + createDF(40, 49).write.mode(SaveMode.Append).insertInto("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 45"), (6 to 44).map(i => Row(i, s"str$i"))) - createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") + createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 51 AND p.c1 < 55"), (52 to 54).map(i => Row(i, s"str$i"))) - createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") + createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM t p"), (50 to 59).map(i => Row(i, s"str$i"))) - createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") + createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM t p"), (70 to 79).map(i => Row(i, s"str$i"))) } } @@ -1153,7 +1145,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("create a temp view using hive") { val tableName = "tab1" - withTable(tableName) { + withTempView(tableName) { val e = intercept[AnalysisException] { sql( s""" @@ -1344,18 +1336,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } - Seq("orc", "parquet", "csv", "json", "text").foreach { format => - test(s"SPARK-22146: read files containing special characters using $format") { - val nameWithSpecialChars = s"sp&cial%chars" - withTempDir { dir => - val tmpFile = s"$dir/$nameWithSpecialChars" - spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) - val fileContent = spark.read.format(format).load(tmpFile) - checkAnswer(fileContent, Seq(Row("a"), Row("b"))) - } - } - } - private def withDebugMode(f: => Unit): Unit = { val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala new file mode 100644 index 000000000000..1e525c46a9cf --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleton + with BeforeAndAfter with SQLTestUtils { + + import spark.implicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)") + (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)")) + } + + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS metadata_only") + } finally { + super.afterAll() + } + } + + test("SPARK-23877: validate metadata-only query pushes filters to metastore") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { + val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount + + // verify the number of matching partitions + assert(sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").collect().length === 5) + + // verify that the partition predicate was pushed down to the metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount === 5) + } + } + + test("SPARK-23877: filter on projected expression") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { + val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount + + // verify the matching partitions + val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr, + Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]), + spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child))) + .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType)))) + + checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x")) + + // verify that the partition predicate was not pushed down to the metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 9440a17677eb..80afc9d8f44b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -37,11 +37,11 @@ class PartitionProviderCompatibilitySuite spark.range(5).selectExpr("id as fieldOne", "id as partCol").write .partitionBy("partCol") .mode("overwrite") - .parquet(dir.getAbsolutePath) + .save(dir.getAbsolutePath) spark.sql(s""" |create table $tableName (fieldOne long, partCol int) - |using parquet + |using ${spark.sessionState.conf.defaultDataSourceName} |options (path "${dir.toURI}") |partitioned by (partCol)""".stripMargin) } @@ -358,7 +358,7 @@ class PartitionProviderCompatibilitySuite try { spark.sql(s""" |create table test (id long, P1 int, P2 int) - |using parquet + |using ${spark.sessionState.conf.defaultDataSourceName} |options (path "${base.toURI}") |partitioned by (P1, 
P2)""".stripMargin) spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.toURI}'") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 54d3962a46b4..3af163af0968 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -417,9 +417,9 @@ class PartitionedTablePerfStatsSuite import spark.implicits._ Seq(1).toDF("a").write.mode("overwrite").save(dir.getAbsolutePath) HiveCatalogMetrics.reset() - spark.read.parquet(dir.getAbsolutePath) + spark.read.load(dir.getAbsolutePath) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) - assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index b2dc401ce1ef..1e396553c9c5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import com.google.common.io.Files import org.apache.hadoop.fs.FileSystem +import org.apache.spark.internal.config._ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -32,41 +33,53 @@ import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import spark.implicits._ - test("SPARK-5068: query data when path doesn't exist") { - withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { - val testData = sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.createOrReplaceTempView("testData") + private def queryWhenPathNotExist(): Unit = { + withTempView("testData") { + withTable("table_with_partition", "createAndInsertTest") { + withTempDir { tmpDir => + val testData = sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.createOrReplaceTempView("testData") + + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE 
table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.union(testData).union(testData).union(testData)) - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.union(testData).union(testData)) + } + } + } + } - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + test("SPARK-5068: query data when path doesn't exist") { + withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "true") { + queryWhenPathNotExist() + } + } - sql("DROP TABLE IF EXISTS table_with_partition") - sql("DROP TABLE IF EXISTS createAndInsertTest") + test("Replace spark.sql.hive.verifyPartitionPath by spark.files.ignoreMissingFiles") { + withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "false") { + sparkContext.conf.set(IGNORE_MISSING_FILES.key, "true") + queryWhenPathNotExist() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index fad81c7e9474..473bbced41b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -289,7 +289,8 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } private def createRawHiveTable(ddl: String): Unit = { - hiveContext.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.runSqlHive(ddl) + hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.runSqlHive(ddl) } private def checkCreateTable(table: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 3af8af0814bb..61cec8298479 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, HistogramBin, HistogramSerializer} import org.apache.spark.sql.catalyst.util.{DateTimeUtils, StringUtils} import org.apache.spark.sql.execution.command.DDLUtils @@ -177,8 +177,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats0 = checkTableStats(table, 
hasSizeInBytes = true, expectedRowCounts = Some(2)) assert(fetchedStats0.get.colStats == Map( - "a" -> ColumnStat(2, Some(1), Some(2), 0, 4, 4), - "b" -> ColumnStat(1, Some(1), Some(1), 0, 4, 4))) + "a" -> CatalogColumnStat(Some(2), Some("1"), Some("2"), Some(0), Some(4), Some(4)), + "b" -> CatalogColumnStat(Some(1), Some("1"), Some("1"), Some(0), Some(4), Some(4)))) } } @@ -208,8 +208,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats1.colStats == Map( - "C1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4))) + "C1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))) } } @@ -596,7 +596,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") val fetchedStats0 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) - assert(fetchedStats0.get.colStats == Map("c1" -> ColumnStat(0, None, None, 0, 4, 4))) + assert(fetchedStats0.get.colStats == + Map("c1" -> CatalogColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)))) // Insert new data and analyze: have the latest column stats. sql(s"INSERT INTO TABLE $table SELECT 1, 'a', 10.0") @@ -604,18 +605,18 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats1.colStats == Map( - "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4))) + "c1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))) // Analyze another column: since the table is not changed, the precious column stats are kept. sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats2.colStats == Map( - "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4), - "c2" -> ColumnStat(distinctCount = 1, min = None, max = None, nullCount = 0, - avgLen = 1, maxLen = 1))) + "c1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c2" -> CatalogColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)))) // Insert new data and analyze: stale column stats are removed and newly collected column // stats are added. 
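Aside (editor's sketch): the assertions above move from ColumnStat to CatalogColumnStat, whose fields are all Options and whose min/max are carried as strings, i.e. the form in which statistics are persisted to table properties in the metastore. Mirroring the values used in the hunks above:

import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat

object CatalogColumnStatSketch {
  // Every field is optional; min/max are plain strings so they round-trip through
  // spark.sql.statistics.colStats.* table properties unchanged.
  val statsForIntColumn: CatalogColumnStat = CatalogColumnStat(
    distinctCount = Some(2),
    min = Some("1"),
    max = Some("2"),
    nullCount = Some(0),
    avgLen = Some(4),
    maxLen = Some(4))
}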
@@ -624,10 +625,10 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats3 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)).get assert(fetchedStats3.colStats == Map( - "c1" -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - "c3" -> ColumnStat(distinctCount = 2, min = Some(10.0), max = Some(20.0), nullCount = 0, - avgLen = 8, maxLen = 8))) + "c1" -> CatalogColumnStat(distinctCount = Some(2), min = Some("1"), max = Some("2"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c3" -> CatalogColumnStat(distinctCount = Some(2), min = Some("10.0"), max = Some("20.0"), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)))) } } @@ -999,115 +1000,11 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("verify serialized column stats after analyzing columns") { import testImplicits._ - val tableName = "column_stats_test2" + val tableName = "column_stats_test_ser" // (data.head.productArity - 1) because the last column does not support stats collection. assert(stats.size == data.head.productArity - 1) val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) - val expectedSerializedColStats = Map( - "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", - "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", - "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", - "spark.sql.statistics.colStats.cbinary.version" -> "1", - "spark.sql.statistics.colStats.cbool.avgLen" -> "1", - "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbool.max" -> "true", - "spark.sql.statistics.colStats.cbool.maxLen" -> "1", - "spark.sql.statistics.colStats.cbool.min" -> "false", - "spark.sql.statistics.colStats.cbool.nullCount" -> "1", - "spark.sql.statistics.colStats.cbool.version" -> "1", - "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", - "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbyte.max" -> "2", - "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", - "spark.sql.statistics.colStats.cbyte.min" -> "1", - "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", - "spark.sql.statistics.colStats.cbyte.version" -> "1", - "spark.sql.statistics.colStats.cdate.avgLen" -> "4", - "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", - "spark.sql.statistics.colStats.cdate.maxLen" -> "4", - "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", - "spark.sql.statistics.colStats.cdate.nullCount" -> "1", - "spark.sql.statistics.colStats.cdate.version" -> "1", - "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", - "spark.sql.statistics.colStats.cdecimal.version" -> "1", - "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", - "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdouble.max" -> "6.0", - "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", - "spark.sql.statistics.colStats.cdouble.min" -> "1.0", - 
"spark.sql.statistics.colStats.cdouble.nullCount" -> "1", - "spark.sql.statistics.colStats.cdouble.version" -> "1", - "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", - "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", - "spark.sql.statistics.colStats.cfloat.max" -> "7.0", - "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", - "spark.sql.statistics.colStats.cfloat.min" -> "1.0", - "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", - "spark.sql.statistics.colStats.cfloat.version" -> "1", - "spark.sql.statistics.colStats.cint.avgLen" -> "4", - "spark.sql.statistics.colStats.cint.distinctCount" -> "2", - "spark.sql.statistics.colStats.cint.max" -> "4", - "spark.sql.statistics.colStats.cint.maxLen" -> "4", - "spark.sql.statistics.colStats.cint.min" -> "1", - "spark.sql.statistics.colStats.cint.nullCount" -> "1", - "spark.sql.statistics.colStats.cint.version" -> "1", - "spark.sql.statistics.colStats.clong.avgLen" -> "8", - "spark.sql.statistics.colStats.clong.distinctCount" -> "2", - "spark.sql.statistics.colStats.clong.max" -> "5", - "spark.sql.statistics.colStats.clong.maxLen" -> "8", - "spark.sql.statistics.colStats.clong.min" -> "1", - "spark.sql.statistics.colStats.clong.nullCount" -> "1", - "spark.sql.statistics.colStats.clong.version" -> "1", - "spark.sql.statistics.colStats.cshort.avgLen" -> "2", - "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", - "spark.sql.statistics.colStats.cshort.max" -> "3", - "spark.sql.statistics.colStats.cshort.maxLen" -> "2", - "spark.sql.statistics.colStats.cshort.min" -> "1", - "spark.sql.statistics.colStats.cshort.nullCount" -> "1", - "spark.sql.statistics.colStats.cshort.version" -> "1", - "spark.sql.statistics.colStats.cstring.avgLen" -> "3", - "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", - "spark.sql.statistics.colStats.cstring.maxLen" -> "3", - "spark.sql.statistics.colStats.cstring.nullCount" -> "1", - "spark.sql.statistics.colStats.cstring.version" -> "1", - "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", - "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", - "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", - "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", - "spark.sql.statistics.colStats.ctimestamp.version" -> "1" - ) - - val expectedSerializedHistograms = Map( - "spark.sql.statistics.colStats.cbyte.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), - "spark.sql.statistics.colStats.cshort.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), - "spark.sql.statistics.colStats.cint.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), - "spark.sql.statistics.colStats.clong.histogram" -> - HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), - "spark.sql.statistics.colStats.cdouble.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), - "spark.sql.statistics.colStats.cfloat.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), - "spark.sql.statistics.colStats.cdecimal.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), - "spark.sql.statistics.colStats.cdate.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), - "spark.sql.statistics.colStats.ctimestamp.histogram" 
-> - HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) - ) - def checkColStatsProps(expected: Map[String, String]): Unit = { sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) val table = hiveClient.getTable("default", tableName) @@ -1129,6 +1026,29 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("verify column stats can be deserialized from tblproperties") { + import testImplicits._ + + val tableName = "column_stats_test_de" + // (data.head.productArity - 1) because the last column does not support stats collection. + assert(stats.size == data.head.productArity - 1) + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Put in stats properties manually. + val table = getCatalogTable(tableName) + val newTable = table.copy( + properties = table.properties ++ + expectedSerializedColStats ++ expectedSerializedHistograms + + ("spark.sql.statistics.totalSize" -> "1") /* totalSize always required */) + hiveClient.alterTable(newTable) + + validateColStats(tableName, statsWithHgms) + } + } + test("serialization and deserialization of histograms to/from hive metastore") { import testImplicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala index ae804ce7c7b0..ab73f668c6ca 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala @@ -46,13 +46,15 @@ private[client] object HiveClientBuilder { def buildClient( version: String, hadoopConf: Configuration, - extraConf: Map[String, String] = Map.empty): HiveClient = { + extraConf: Map[String, String] = Map.empty, + sharesHadoopClasses: Boolean = true): HiveClient = { IsolatedClientLoader.forVersion( hiveMetastoreVersion = version, hadoopVersion = VersionInfo.getVersion, sparkConf = new SparkConf(), hadoopConf = hadoopConf, config = buildConf(extraConf), - ivyPath = ivyPath).createClient() + ivyPath = ivyPath, + sharesHadoopClasses = sharesHadoopClasses).createClient() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index ce53acef5150..f991352b207d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -67,6 +67,7 @@ class HiveClientSuite(version: String) } override def beforeAll() { + super.beforeAll() client = init(true) } @@ -201,6 +202,10 @@ class HiveClientSuite(version: String) day1 :: day2 :: Nil) } + test("create client with sharesHadoopClasses = false") { + buildClient(new Configuration(), sharesHadoopClasses = false) + } + private def testMetastorePartitionFiltering( filterString: String, expectedDs: Seq[Int], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala index 2e7dfde8b2fa..30592a3f8542 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala @@ -22,5 +22,6 @@ import scala.collection.immutable.IndexedSeq import 
org.apache.spark.SparkFunSuite private[client] trait HiveClientVersions { - protected val versions = IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") + protected val versions = + IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1", "2.2", "2.3") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index 951ebfad4590..e5963d03f6b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -25,18 +25,24 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.hive.HiveUtils private[client] abstract class HiveVersionSuite(version: String) extends SparkFunSuite { + override protected val enableAutoThreadAudit = false protected var client: HiveClient = null - protected def buildClient(hadoopConf: Configuration): HiveClient = { + protected def buildClient( + hadoopConf: Configuration, + sharesHadoopClasses: Boolean = true): HiveClient = { // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 - if (version == "2.0" || version == "2.1") { + if (version == "2.0" || version == "2.1" || version == "2.2" || version == "2.3") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } - HiveClientBuilder - .buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) + HiveClientBuilder.buildClient( + version, + hadoopConf, + HiveUtils.formatTimeVarsForHiveClient(hadoopConf), + sharesHadoopClasses = sharesHadoopClasses) } override def suiteName: String = s"${super.suiteName}($version)" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 94473a08dd31..dc96ec416afd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayOutputStream, File, PrintStream, PrintWriter} import java.net.URI import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat @@ -50,6 +51,8 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { + override protected val enableAutoThreadAudit = false + import HiveClientBuilder.buildClient /** @@ -108,7 +111,8 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") + private val versions = + Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1", "2.2", "2.3") private var client: HiveClient = null @@ -123,15 +127,15 @@ class VersionsSuite extends SparkFunSuite with Logging { // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false 
to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 - if (version == "2.0" || version == "2.1") { + if (version == "2.0" || version == "2.1" || version == "2.2" || version == "2.3") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } client = buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) if (versionSpark != null) versionSpark.reset() versionSpark = TestHiveVersion(client) - assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - .version.fullVersion.startsWith(version)) + assert(versionSpark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.version.fullVersion.startsWith(version)) } def table(database: String, tableName: String): CatalogTable = { @@ -163,6 +167,15 @@ class VersionsSuite extends SparkFunSuite with Logging { client.createDatabase(tempDB, ignoreIfExists = true) } + test(s"$version: createDatabase with null description") { + withTempDir { tmpDir => + val dbWithNullDesc = + CatalogDatabase("dbWithNullDesc", description = null, tmpDir.toURI, Map()) + client.createDatabase(dbWithNullDesc, ignoreIfExists = true) + assert(client.getDatabase("dbWithNullDesc").description == "") + } + } + test(s"$version: setCurrentDatabase") { client.setCurrentDatabase("default") } @@ -411,15 +424,18 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: alterPartitions") { val spec = Map("key1" -> "1", "key2" -> "2") + val parameters = Map(StatsSetupConst.TOTAL_SIZE -> "0", StatsSetupConst.NUM_FILES -> "1") val newLocation = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) val storage = storageFormat.copy( locationUri = Some(newLocation), // needed for 0.12 alter partitions serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - val partition = CatalogTablePartition(spec, storage) + val partition = CatalogTablePartition(spec, storage, parameters) client.alterPartitions("default", "src_part", Seq(partition)) assert(client.getPartition("default", "src_part", spec) .storage.locationUri == Some(newLocation)) + assert(client.getPartition("default", "src_part", spec) + .parameters.get(StatsSetupConst.TOTAL_SIZE) == Some("0")) } test(s"$version: dropPartitions") { @@ -622,6 +638,46 @@ class VersionsSuite extends SparkFunSuite with Logging { } } + test(s"$version: CREATE Partitioned TABLE AS SELECT") { + withTable("tbl") { + versionSpark.sql( + """ + |CREATE TABLE tbl(c1 string) + |PARTITIONED BY (ds STRING) + """.stripMargin) + versionSpark.sql("INSERT OVERWRITE TABLE tbl partition (ds='2') SELECT '1'") + + assert(versionSpark.table("tbl").collect().toSeq == Seq(Row("1", "2"))) + val partMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + val totalSize = partMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val numFiles = partMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(totalSize.isEmpty && numFiles.isEmpty) + } else { + assert(totalSize.nonEmpty && numFiles.nonEmpty) + } + + versionSpark.sql( + """ + |ALTER TABLE tbl PARTITION (ds='2') + |SET SERDEPROPERTIES ('newKey' = 'vvv') + """.stripMargin) + val newPartMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + + val newTotalSize = 
newPartMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val newNumFiles = newPartMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(newTotalSize.isEmpty && newNumFiles.isEmpty) + } else { + assert(newTotalSize.nonEmpty && newNumFiles.nonEmpty) + } + } + } + test(s"$version: Delete the temporary staging directory and files after each insert") { withTempDir { tmpDir => withTable("tab") { @@ -802,7 +858,7 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: read avro file containing decimal") { val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val location = new File(url.getFile) + val location = new File(url.getFile).toURI.toString val tableName = "tab1" val avroSchema = @@ -842,6 +898,8 @@ class VersionsSuite extends SparkFunSuite with Logging { } test(s"$version: SPARK-17920: Insert into/overwrite avro table") { + // skipped because it's failed in the condition on Windows + assume(!(Utils.isWindows && version == "0.12")) withTempDir { dir => val avroSchema = """ @@ -866,10 +924,10 @@ class VersionsSuite extends SparkFunSuite with Logging { val writer = new PrintWriter(schemaFile) writer.write(avroSchema) writer.close() - val schemaPath = schemaFile.getCanonicalPath + val schemaPath = schemaFile.toURI.toString val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val srcLocation = new File(url.getFile).getCanonicalPath + val srcLocation = new File(url.getFile).toURI.toString val destTableName = "tab1" val srcTableName = "tab2" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index cee82cda4628..272e6f51f500 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -48,6 +48,8 @@ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} abstract class HiveComparisonTest extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen { + override protected val enableAutoThreadAudit = false + /** * Path to the test datasets. We find this by looking up "hive-test-path-helper.txt" file. 
* diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index f2e0c695ca38..0341c3b37891 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METAS import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -875,12 +876,13 @@ class HiveDDLSuite test("desc table for Hive table - bucketed + sorted table") { withTable("tbl") { - sql(s""" - CREATE TABLE tbl (id int, name string) - PARTITIONED BY (ds string) - CLUSTERED BY(id) - SORTED BY(id, name) INTO 1024 BUCKETS - """) + sql( + s""" + |CREATE TABLE tbl (id int, name string) + |CLUSTERED BY(id) + |SORTED BY(id, name) INTO 1024 BUCKETS + |PARTITIONED BY (ds string) + """.stripMargin) val x = sql("DESC FORMATTED tbl").collect() assert(x.containsSlice( @@ -1353,7 +1355,8 @@ class HiveDDLSuite val indexName = tabName + "_index" withTable(tabName) { // Spark SQL does not support creating index. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client sql(s"CREATE TABLE $tabName(a int)") try { @@ -1391,7 +1394,8 @@ class HiveDDLSuite val tabName = "tab1" withTable(tabName) { // Spark SQL does not support creating skewed table. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client client.runSqlHive( s""" |CREATE Table $tabName(col1 int, col2 int) @@ -1460,7 +1464,7 @@ class HiveDDLSuite assert(e2.getMessage.contains(forbiddenPrefix + "foo")) val e3 = intercept[AnalysisException] { - sql(s"CREATE TABLE tbl (a INT) TBLPROPERTIES ('${forbiddenPrefix}foo'='anything')") + sql(s"CREATE TABLE tbl2 (a INT) TBLPROPERTIES ('${forbiddenPrefix}foo'='anything')") } assert(e3.getMessage.contains(forbiddenPrefix + "foo")) } @@ -1657,8 +1661,8 @@ class HiveDDLSuite Seq(5 -> "e").toDF("i", "j") .write.format("hive").mode("append").saveAsTable("t1") } - assert(e.message.contains("The format of the existing table default.t1 is " + - "`ParquetFileFormat`. 
It doesn't match the specified format `HiveFileFormat`.")) + assert(e.message.contains("The format of the existing table default.t1 is ")) + assert(e.message.contains("It doesn't match the specified format `HiveFileFormat`.")) } } @@ -1708,11 +1712,12 @@ class HiveDDLSuite spark.sessionState.catalog.getTableMetadata(TableIdentifier(tblName)).schema.map(_.name) } + val provider = spark.sessionState.conf.defaultDataSourceName withTable("t", "t1", "t2", "t3", "t4", "t5", "t6") { - sql("CREATE TABLE t(a int, b int, c int, d int) USING parquet PARTITIONED BY (d, b)") + sql(s"CREATE TABLE t(a int, b int, c int, d int) USING $provider PARTITIONED BY (d, b)") assert(getTableColumns("t") == Seq("a", "c", "d", "b")) - sql("CREATE TABLE t1 USING parquet PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d") + sql(s"CREATE TABLE t1 USING $provider PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d") assert(getTableColumns("t1") == Seq("a", "c", "d", "b")) Seq((1, 1, 1, 1)).toDF("a", "b", "c", "d").write.partitionBy("d", "b").saveAsTable("t2") @@ -1722,7 +1727,7 @@ class HiveDDLSuite val dataPath = new File(new File(path, "d=1"), "b=1").getCanonicalPath Seq(1 -> 1).toDF("a", "c").write.save(dataPath) - sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}'") + sql(s"CREATE TABLE t3 USING $provider LOCATION '${path.toURI}'") assert(getTableColumns("t3") == Seq("a", "c", "d", "b")) } @@ -2142,6 +2147,86 @@ class HiveDDLSuite } } + private def getReader(path: String): org.apache.orc.Reader = { + val conf = spark.sessionState.newHadoopConf() + val files = org.apache.spark.sql.execution.datasources.orc.OrcUtils.listOrcFiles(path, conf) + assert(files.length == 1) + val file = files.head + val fs = file.getFileSystem(conf) + val readerOptions = org.apache.orc.OrcFile.readerOptions(conf).filesystem(fs) + org.apache.orc.OrcFile.createReader(file, readerOptions) + } + + test("SPARK-23355 convertMetastoreOrc should not ignore table properties - STORED AS") { + Seq("native", "hive").foreach { orcImpl => + withSQLConf(ORC_IMPLEMENTATION.key -> orcImpl, CONVERT_METASTORE_ORC.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS ORC + |TBLPROPERTIES ( + | orc.compress 'ZLIB', + | orc.compress.size '1001', + | orc.row.index.stride '2002', + | hive.exec.orc.default.block.size '3003', + | hive.exec.orc.compression.strategy 'COMPRESSION') + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("orc")) + val properties = table.properties + assert(properties.get("orc.compress") == Some("ZLIB")) + assert(properties.get("orc.compress.size") == Some("1001")) + assert(properties.get("orc.row.index.stride") == Some("2002")) + assert(properties.get("hive.exec.orc.default.block.size") == Some("3003")) + assert(properties.get("hive.exec.orc.compression.strategy") == Some("COMPRESSION")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + val reader = getReader(maybeFile.head.getCanonicalPath) + assert(reader.getCompressionKind.name === "ZLIB") + assert(reader.getCompressionSize == 1001) + assert(reader.getRowIndexStride == 2002) + } + } + } + } + } + + test("SPARK-23355 convertMetastoreParquet should not ignore table properties - STORED AS") { + 
withSQLConf(CONVERT_METASTORE_PARQUET.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS PARQUET + |TBLPROPERTIES ( + | parquet.compression 'GZIP' + |) + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("parquet")) + val properties = table.properties + assert(properties.get("parquet.compression") == Some("GZIP")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + assertCompression(maybeFile, "parquet", "GZIP") + } + } + } + } + test("load command for non local invalid path validation") { withTable("tbl") { sql("CREATE TABLE tbl(i INT, j STRING)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index dfabf1ec2a22..5d56f89c2271 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -128,40 +128,39 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "src") } - test("SPARK-17409: The EXPLAIN output of CTAS only shows the analyzed plan") { - withTempView("jt") { - val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""").toDS() - spark.read.json(ds).createOrReplaceTempView("jt") - val outputs = sql( - s""" - |EXPLAIN EXTENDED - |CREATE TABLE t1 - |AS - |SELECT * FROM jt - """.stripMargin).collect().map(_.mkString).mkString - - val shouldContain = - "== Parsed Logical Plan ==" :: "== Analyzed Logical Plan ==" :: "Subquery" :: - "== Optimized Logical Plan ==" :: "== Physical Plan ==" :: - "CreateHiveTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil - for (key <- shouldContain) { - assert(outputs.contains(key), s"$key doesn't exist in result") - } - - val physicalIndex = outputs.indexOf("== Physical Plan ==") - assert(outputs.substring(physicalIndex).contains("Subquery"), - "Physical Plan should contain SubqueryAlias since the query should not be optimized") - } + test("explain output of physical plan should contain proper codegen stage ID") { + checkKeywordsExist(sql( + """ + |EXPLAIN SELECT t1.id AS a, t2.id AS b FROM + |(SELECT * FROM range(3)) t1 JOIN + |(SELECT * FROM range(10)) t2 ON t1.id == t2.id % 3 + """.stripMargin), + "== Physical Plan ==", + "*(2) Project ", + "+- *(2) BroadcastHashJoin ", + " :- BroadcastExchange ", + " : +- *(1) Range ", + " +- *(2) Range " + ) } test("EXPLAIN CODEGEN command") { - checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"), - "WholeStageCodegen", - "Generated code:", - "/* 001 */ public Object generate(Object[] references) {", - "/* 002 */ return new GeneratedIterator(references);", - "/* 003 */ }" - ) + // the generated class name in this test should stay in sync with + // org.apache.spark.sql.execution.WholeStageCodegenExec.generatedClassName() + for ((useIdInClassName, expectedClassName) <- Seq( + ("true", "GeneratedIteratorForCodegenStage1"), + ("false", "GeneratedIterator"))) { + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> useIdInClassName) { + checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"), + "WholeStageCodegen", + "Generated code:", + "/* 001 */ public Object 
generate(Object[] references) {", + s"/* 002 */ return new $expectedClassName(references);", + "/* 003 */ }" + ) + } + } checkKeywordsNotExist(sql("EXPLAIN CODEGEN SELECT 1"), "== Physical Plan ==" @@ -171,4 +170,21 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("EXPLAIN EXTENDED CODEGEN SELECT 1") } } + + test("SPARK-23021 AnalysisBarrier should not cut off explain output for parsed logical plans") { + val df = Seq((1, 1)).toDF("a", "b").groupBy("a").count().limit(1) + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + df.explain(true) + } + assert(outputStream.toString.replaceAll("""#\d+""", "#0").contains( + s"""== Parsed Logical Plan == + |GlobalLimit 1 + |+- LocalLimit 1 + | +- AnalysisBarrier + | +- Aggregate [a#0], [a#0, count(1) AS count#0L] + | +- Project [_1#0 AS a#0, _2#0 AS b#0] + | +- LocalRelation [_1#0, _2#0] + |""".stripMargin)) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala index 97e4c2b6b2db..5e6e114fc3fd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala @@ -67,20 +67,22 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper].getCanonicalName withUserDefinedFunction(tempFunctionName -> true) { sql(s"CREATE TEMPORARY FUNCTION $tempFunctionName AS '$functionClass'") - withView("view1", "tempView1") { - withTable("tab1") { - (1 to 10).map(i => s"$i").toDF("id").write.saveAsTable("tab1") + withView("view1") { + withTempView("tempView1") { + withTable("tab1") { + (1 to 10).map(i => s"$i").toDF("id").write.saveAsTable("tab1") - // temporary view - sql(s"CREATE TEMPORARY VIEW tempView1 AS SELECT $tempFunctionName(id) from tab1") - checkAnswer(sql("select count(*) FROM tempView1"), Row(10)) + // temporary view + sql(s"CREATE TEMPORARY VIEW tempView1 AS SELECT $tempFunctionName(id) from tab1") + checkAnswer(sql("select count(*) FROM tempView1"), Row(10)) - // permanent view - val e = intercept[AnalysisException] { - sql(s"CREATE VIEW view1 AS SELECT $tempFunctionName(id) from tab1") - }.getMessage - assert(e.contains("Not allowed to create a permanent view `view1` by referencing " + - s"a temporary function `$tempFunctionName`")) + // permanent view + val e = intercept[AnalysisException] { + sql(s"CREATE VIEW view1 AS SELECT $tempFunctionName(id) from tab1") + }.getMessage + assert(e.contains("Not allowed to create a permanent view `view1` by referencing " + + s"a temporary function `$tempFunctionName`")) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 1c9f00141ae1..d7752e987cb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -100,6 +100,25 @@ class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfte assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) } + + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = 
extractTableDesc( + "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS textfile") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } + + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc( + "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS sequencefile") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.SequenceFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.mapred.SequenceFileOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } } test("create hive serde table with new syntax - basic") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 8986fb58c646..7402c9626873 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -49,8 +49,12 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } protected override def afterAll(): Unit = { - sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") - sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + try { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } finally { + super.afterAll() + } } test("built-in Hive UDAF") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala index bc828877e35e..eaedac1fa95d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala @@ -74,7 +74,11 @@ class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton with Before } override def afterAll(): Unit = { - catalog = null + try { + catalog = null + } finally { + super.afterAll() + } } test("SPARK-21617: ALTER TABLE for non-compatible DataSource tables") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 9eaf44c043c7..8dbcd24cd78d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -47,7 +47,11 @@ class ObjectHashAggregateSuite } protected override def afterAll(): Unit = { - sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + try { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } finally { + super.afterAll() + } } test("typed_count without grouping keys") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 07ae3ae94584..828c18a770c8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -461,75 +461,74 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS without serde without location") { - val originalConf = sessionState.conf.convertCTAS - - setConf(SQLConf.CONVERT_CTAS, true) - - val defaultDataSource = sessionState.conf.defaultDataSourceName - try { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - val message = intercept[AnalysisException] { + withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { + val defaultDataSource = sessionState.conf.defaultDataSourceName + withTable("ctas1") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("already exists")) - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + val message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("already exists")) + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } // Specifying database name for query can be converted to data source write path // is not allowed right now. - sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } - sql("CREATE TABLE ctas1 stored as textfile" + + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as textfile" + " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text") - sql("DROP TABLE ctas1") + checkRelation("ctas1", isDataSourceTable = false, "text") + } - sql("CREATE TABLE ctas1 stored as sequencefile" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "sequence") + } - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "rcfile") + } - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "orc") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "orc") + } - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "parquet") - sql("DROP TABLE ctas1") - } finally { - setConf(SQLConf.CONVERT_CTAS, originalConf) - sql("DROP TABLE IF EXISTS ctas1") + withTable("ctas1") { + sql( + """ + |CREATE TABLE ctas1 stored as parquet + |AS SELECT key k, value FROM src ORDER BY k, value + """.stripMargin) + 
checkRelation("ctas1", isDataSourceTable = false, "parquet") + } } } test("CTAS with default fileformat") { val table = "ctas1" val ctas = s"CREATE TABLE IF NOT EXISTS $table SELECT key k, value FROM src" - withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { - withSQLConf("hive.default.fileformat" -> "textfile") { + Seq("orc", "parquet").foreach { dataSourceFormat => + withSQLConf( + SQLConf.CONVERT_CTAS.key -> "true", + SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> dataSourceFormat, + "hive.default.fileformat" -> "textfile") { withTable(table) { sql(ctas) - // We should use parquet here as that is the default datasource fileformat. The default - // datasource file format is controlled by `spark.sql.sources.default` configuration. + // The default datasource file format is controlled by `spark.sql.sources.default`. // This testcase verifies that setting `hive.default.fileformat` has no impact on // the target table's fileformat in case of CTAS. - assert(sessionState.conf.defaultDataSourceName === "parquet") - checkRelation(tableName = table, isDataSourceTable = true, format = "parquet") + checkRelation(tableName = table, isDataSourceTable = true, format = dataSourceFormat) } } - withSQLConf("spark.sql.sources.default" -> "orc") { - withTable(table) { - sql(ctas) - checkRelation(tableName = table, isDataSourceTable = true, format = "orc") - } - } } } @@ -539,30 +538,40 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val defaultDataSource = sessionState.conf.defaultDataSourceName val tempLocation = dir.toURI.getPath.stripSuffix("/") - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c1")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c1")) + } - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c2")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c2")) + } - sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text", Some(s"file:$tempLocation/c3")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "text", Some(s"file:$tempLocation/c3")) + } - sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence", Some(s"file:$tempLocation/c4")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "sequence", Some(s"file:$tempLocation/c4")) + } - 
sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile", Some(s"file:$tempLocation/c5")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "rcfile", Some(s"file:$tempLocation/c5")) + } } } } @@ -1189,35 +1198,37 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("specifying database name for a temporary view is not allowed") { withTempPath { dir => - val path = dir.toURI.toString - val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") - df - .write - .format("parquet") - .save(path) - - // We don't support creating a temporary table while specifying a database - intercept[AnalysisException] { + withTempView("db.t") { + val path = dir.toURI.toString + val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + // We don't support creating a temporary table while specifying a database + intercept[AnalysisException] { + spark.sql( + s""" + |CREATE TEMPORARY VIEW db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + } + + // If you use backticks to quote the name then it's OK. spark.sql( s""" - |CREATE TEMPORARY VIEW db.t + |CREATE TEMPORARY VIEW `db.t` |USING parquet |OPTIONS ( | path '$path' |) """.stripMargin) + checkAnswer(spark.table("`db.t`"), df) } - - // If you use backticks to quote the name then it's OK. - spark.sql( - s""" - |CREATE TEMPORARY VIEW `db.t` - |USING parquet - |OPTIONS ( - | path '$path' - |) - """.stripMargin) - checkAnswer(spark.table("`db.t`"), df) } } @@ -2088,7 +2099,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Seq("orc", "parquet").foreach { format => test(s"SPARK-18355 Read data from a hive table with a new column - $format") { - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client Seq("true", "false").foreach { value => withSQLConf( @@ -2146,11 +2158,34 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - Seq("orc", "parquet", "csv", "json", "text").foreach { format => - test(s"Writing empty datasets should not fail - $format") { - withTempDir { dir => - Seq("str").toDS.limit(0).write.format(format).save(dir.getCanonicalPath + "/tmp") + test("SPARK-24085 scalar subquery in partitioning expression") { + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted", + "hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(format) { + withTempPath { tempDir => + sql( + s""" + |CREATE TABLE ${format} (id_value string) + |PARTITIONED BY (id_type string) + |LOCATION '${tempDir.toURI}' + |STORED AS ${format} + """.stripMargin) + sql(s"insert into $format values ('1','a')") + sql(s"insert into $format values ('2','a')") + sql(s"insert into $format values ('3','b')") + sql(s"insert into $format values ('4','b')") + checkAnswer( + sql(s"SELECT * FROM $format WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", 
"b") :: Nil) + } + } + } } } } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala index 92b2f069cacd..597b0f56a55e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala @@ -208,4 +208,14 @@ class HiveOrcQuerySuite extends OrcQueryTest with TestHiveSingleton { } } } + + test("SPARK-23340 Empty float/double array columns raise EOFException") { + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "false") { + withTable("spark_23340") { + sql("CREATE TABLE spark_23340(a array, b array) STORED AS ORC") + sql("INSERT INTO spark_23340 VALUES (array(), array())") + checkAnswer(spark.table("spark_23340"), Seq(Row(Array.empty[Float], Array.empty[Double]))) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 17b7d8cfe127..d556a030e218 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.orc.OrcSuite import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.util.Utils class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { @@ -62,6 +64,33 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { """.stripMargin) } + test("SPARK-22972: hive orc source") { + val tableName = "normal_orc_as_source_hive" + withTable(tableName) { + sql( + s""" + |CREATE TABLE $tableName + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' + |) + """.stripMargin) + + val tableMetadata = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(tableName)) + assert(tableMetadata.storage.inputFormat == + Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(tableMetadata.storage.outputFormat == + Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(tableMetadata.storage.serde == + Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.hive.orc") + .equals(HiveSerDe.sourceToSerDe("orc"))) + assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.orc") + .equals(HiveSerDe.sourceToSerDe("orc"))) + } + } + test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { val location = Utils.createTempDir() val uri = location.toURI diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index f87162f94c01..3b82a6c458ce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -30,14 +30,14 @@ import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { import testImplicits._ + override protected val enableAutoThreadAudit = false override val dataSourceName: String = 
classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName - // ORC does not play well with NullType and UDT. + // ORC does not play well with NullType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false case _: CalendarIntervalType => false - case _: UserDefinedType[_] => false case _ => true } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala new file mode 100644 index 000000000000..bf6efa7c4c08 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -0,0 +1,501 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure ORC read performance. + * + * This is in `sql/hive` module in order to compare `sql/core` and `sql/hive` ORC data sources. + */ +// scalastyle:off line.size.limit +object OrcReadBenchmark { + val conf = new SparkConf() + conf.set("orc.compression", "snappy") + + private val spark = SparkSession.builder() + .master("local[1]") + .appName("OrcReadBenchmark") + .config(conf) + .getOrCreate() + + // Set default configs. Individual cases will change them if necessary. 
+ spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private val NATIVE_ORC_FORMAT = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName + private val HIVE_ORC_FORMAT = classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName + + private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + val dirORC = dir.getCanonicalPath + + if (partition.isDefined) { + df.write.partitionBy(partition.get).orc(dirORC) + } else { + df.write.orc(dirORC) + } + + spark.read.format(NATIVE_ORC_FORMAT).load(dirORC).createOrReplaceTempView("nativeOrcTable") + spark.read.format(HIVE_ORC_FORMAT).load(dirORC).createOrReplaceTempView("hiveOrcTable") + } + + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { + val benchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1135 / 1171 13.9 72.2 1.0X + Native ORC Vectorized 152 / 163 103.4 9.7 7.5X + Native ORC Vectorized with copy 149 / 162 105.4 9.5 7.6X + Hive built-in ORC 1380 / 1384 11.4 87.7 0.8X + + SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1182 / 1244 13.3 75.2 1.0X + Native ORC Vectorized 145 / 156 108.7 9.2 8.2X + Native ORC Vectorized with copy 148 / 158 106.4 9.4 8.0X + Hive built-in ORC 1591 / 1636 9.9 101.2 0.7X + + SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1271 / 1271 12.4 80.8 
1.0X + Native ORC Vectorized 206 / 212 76.3 13.1 6.2X + Native ORC Vectorized with copy 200 / 213 78.8 12.7 6.4X + Hive built-in ORC 1776 / 1787 8.9 112.9 0.7X + + SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1344 / 1355 11.7 85.4 1.0X + Native ORC Vectorized 258 / 268 61.0 16.4 5.2X + Native ORC Vectorized with copy 252 / 257 62.4 16.0 5.3X + Hive built-in ORC 1818 / 1823 8.7 115.6 0.7X + + SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1333 / 1352 11.8 84.8 1.0X + Native ORC Vectorized 310 / 324 50.7 19.7 4.3X + Native ORC Vectorized with copy 312 / 320 50.4 19.9 4.3X + Hive built-in ORC 1904 / 1918 8.3 121.0 0.7X + + SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1408 / 1585 11.2 89.5 1.0X + Native ORC Vectorized 359 / 368 43.8 22.8 3.9X + Native ORC Vectorized with copy 364 / 371 43.2 23.2 3.9X + Hive built-in ORC 1881 / 1954 8.4 119.6 0.7X + */ + benchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2566 / 2592 4.1 244.7 1.0X + Native ORC Vectorized 1098 / 1113 9.6 104.7 2.3X + Native ORC Vectorized with copy 1527 / 1593 6.9 145.6 1.7X + Hive built-in ORC 3561 / 3705 2.9 339.6 0.7X + */ + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Partitioned Table", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + + benchmark.addCase("Data column - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + 
} + + benchmark.addCase("Data column - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Data column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Data column - Hive built-in ORC") { _ => + spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() + } + + benchmark.addCase("Partition column - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Partition column - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Partition column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Partition column - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p) FROM hiveOrcTable").collect() + } + + benchmark.addCase("Both columns - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Both columns - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Both column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Both columns - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Data only - Native ORC MR 1447 / 1457 10.9 92.0 1.0X + Data only - Native ORC Vectorized 256 / 266 61.4 16.3 5.6X + Data only - Native ORC Vectorized with copy 263 / 273 59.8 16.7 5.5X + Data only - Hive built-in ORC 1960 / 1988 8.0 124.6 0.7X + Partition only - Native ORC MR 1039 / 1043 15.1 66.0 1.4X + Partition only - Native ORC Vectorized 48 / 53 326.6 3.1 30.1X + Partition only - Native ORC Vectorized with copy 48 / 53 328.4 3.0 30.2X + Partition only - Hive built-in ORC 1234 / 1242 12.7 78.4 1.2X + Both columns - Native ORC MR 1465 / 1475 10.7 93.1 1.0X + Both columns - Native ORC Vectorized 292 / 301 53.9 18.6 5.0X + Both column - Native ORC Vectorized with copy 348 / 354 45.1 22.2 4.2X + Both columns - Hive built-in ORC 2051 / 2060 7.7 130.4 0.7X + */ + benchmark.run() + } + } + } + + def repeatedStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Repeated String", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT CAST((id % 200) + 10000 as STRING) AS c1 FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT 
sum(length(c1)) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1271 / 1278 8.3 121.2 1.0X + Native ORC Vectorized 200 / 212 52.4 19.1 6.4X + Native ORC Vectorized with copy 342 / 347 30.7 32.6 3.7X + Hive built-in ORC 1874 / 2105 5.6 178.7 0.7X + */ + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql( + s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + + val benchmark = new Benchmark(s"String with Nulls Scan ($fractionOfNulls%)", values) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM hiveOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2394 / 2886 4.4 228.3 1.0X + Native ORC Vectorized 699 / 729 15.0 66.7 3.4X + Native ORC Vectorized with copy 959 / 1025 10.9 91.5 2.5X + Hive built-in ORC 3899 / 3901 2.7 371.9 0.6X + + String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2234 / 2255 4.7 213.1 1.0X + Native ORC Vectorized 854 / 869 12.3 81.4 2.6X + Native ORC Vectorized with copy 1099 / 1128 9.5 104.8 2.0X + Hive built-in ORC 2767 / 2793 3.8 263.9 0.8X + + String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1166 / 1202 9.0 111.2 1.0X + Native ORC Vectorized 338 / 345 31.1 32.2 3.5X + Native ORC Vectorized with copy 418 / 428 25.1 39.9 2.8X + Hive built-in ORC 1730 / 1761 6.1 164.9 0.7X + */ + benchmark.run() + } + } + } + + def 
columnsBenchmark(values: Int, width: Int): Unit = { + val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + val middle = width / 2 + val selectExpr = (1 to width).map(i => s"value as c$i") + spark.range(values).map(_ => Random.nextLong).toDF() + .selectExpr(selectExpr: _*).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT * FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql(s"SELECT sum(c$middle) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1050 / 1053 1.0 1001.1 1.0X + Native ORC Vectorized 95 / 101 11.0 90.9 11.0X + Native ORC Vectorized with copy 95 / 102 11.0 90.9 11.0X + Hive built-in ORC 348 / 358 3.0 331.8 3.0X + + Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2099 / 2108 0.5 2002.1 1.0X + Native ORC Vectorized 179 / 187 5.8 171.1 11.7X + Native ORC Vectorized with copy 176 / 188 6.0 167.6 11.9X + Hive built-in ORC 562 / 581 1.9 535.9 3.7X + + Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 3221 / 3246 0.3 3071.4 1.0X + Native ORC Vectorized 312 / 322 3.4 298.0 10.3X + Native ORC Vectorized with copy 306 / 320 3.4 291.6 10.5X + Hive built-in ORC 815 / 824 1.3 777.3 4.0X + */ + benchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + numericScanBenchmark(1024 * 1024 * 15, dataType) + } + intStringScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + repeatedStringScanBenchmark(1024 * 1024 * 10) + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + columnsBenchmark(1024 * 1024 * 1, 100) + columnsBenchmark(1024 * 1024 * 1, 200) + columnsBenchmark(1024 * 1024 * 1, 300) + } +} +// scalastyle:on line.size.limit diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 740e0837350c..2327d83a1b4f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -180,15 +180,18 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } override def afterAll(): Unit = { - 
dropTables("partitioned_parquet", - "partitioned_parquet_with_key", - "partitioned_parquet_with_complextypes", - "partitioned_parquet_with_key_and_complextypes", - "normal_parquet", - "jt", - "jt_array", - "test_parquet") - super.afterAll() + try { + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") + } finally { + super.afterAll() + } } test(s"conversion is working") { @@ -931,11 +934,15 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with } override protected def afterAll(): Unit = { - partitionedTableDir.delete() - normalTableDir.delete() - partitionedTableDirWithKey.delete() - partitionedTableDirWithComplexTypes.delete() - partitionedTableDirWithKeyAndComplexTypes.delete() + try { + partitionedTableDir.delete() + normalTableDir.delete() + partitionedTableDirWithKey.delete() + partitionedTableDirWithComplexTypes.delete() + partitionedTableDirWithKeyAndComplexTypes.delete() + } finally { + super.afterAll() + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index df7988f542b7..d50bf0b8fd60 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -26,10 +26,11 @@ import org.apache.spark.sql.hive.client.HiveClient trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { + override protected val enableAutoThreadAudit = false protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected val hiveClient: HiveClient = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client protected override def afterAll(): Unit = { try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 80aff446bc24..53397991e59d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -335,16 +335,17 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") { withTable("t") { - sql("CREATE TABLE t(i INT) USING parquet") - intercept[AnalysisException] { + sql(s"CREATE TABLE t(i INT) USING $dataSourceName") + val msg = intercept[AnalysisException] { testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") - } + }.getMessage + assert(msg.contains("Table `t` already exists")) } } test("saveAsTable()/load() - non-partitioned table - Ignore") { withTable("t") { - sql("CREATE TABLE t(i INT) USING parquet") + sql(s"CREATE TABLE t(i INT) USING $dataSourceName") testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") assert(spark.table("t").collect().isEmpty) } diff --git a/streaming/pom.xml b/streaming/pom.xml index fea882ad1123..4497e53b6598 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index aed67a502743..3703a87cdb9a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -57,9 +57,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.yarn.jars", "spark.yarn.keytab", "spark.yarn.principal", - "spark.yarn.credentials.file", - "spark.yarn.credentials.renewalTime", - "spark.yarn.credentials.updateTime", "spark.ui.filters", "spark.mesos.driver.frameworkId") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 27644a645727..5d38c56aa587 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -159,7 +159,9 @@ private[streaming] class ReceiverSupervisorImpl( logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") val numRecords = blockStoreResult.numRecords val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult) - trackerEndpoint.askSync[Boolean](AddBlock(blockInfo)) + if (!trackerEndpoint.askSync[Boolean](AddBlock(blockInfo))) { + throw new SparkException("Failed to add block to receiver tracker.") + } logDebug(s"Reported block $blockId") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 5d9a8ac0d929..dacff69d55dd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -193,12 +193,15 @@ private[streaming] class ReceivedBlockTracker( getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo } - // Insert the recovered block-to-batch allocations and clear the queue of received blocks - // (when the blocks were originally allocated to the batch, the queue must have been cleared). + // Insert the recovered block-to-batch allocations and remove them from the queue of + // received blocks.
def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) { logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + s"${allocatedBlocks.streamIdToAllocatedBlocks}") - streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } + allocatedBlocks.streamIdToAllocatedBlocks.foreach { + case (streamId, allocatedBlocksInStream) => + getReceivedBlockQueue(streamId).dequeueAll(allocatedBlocksInStream.toSet) + } timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } @@ -227,7 +230,7 @@ private[streaming] class ReceivedBlockTracker( } /** Write an update to the tracker to the write ahead log */ - private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { + private[streaming] def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { if (isWriteAheadLogEnabled) { logTrace(s"Writing record: $record") try { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 6f130c803f31..c74ca1918a81 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -521,7 +521,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (active) { context.reply(addBlock(receivedBlockInfo)) } else { - throw new IllegalStateException("ReceiverTracker RpcEndpoint shut down.") + context.sendFailure( + new IllegalStateException("ReceiverTracker RpcEndpoint already shut down.")) } } }) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 7abafd6ba790..3a176f64cdd6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -490,15 +490,40 @@ private[ui] class StreamingPage(parent: StreamingTab) sortBy(_.batchTime.milliseconds).reverse val activeBatchesContent = { -

      <h4>
-        Active Batches ({runningBatches.size + waitingBatches.size})
-      </h4> ++
-        new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq
+      [container markup lost in extraction: the added lines nest the
+       Active Batches ({runningBatches.size + waitingBatches.size}) heading and
+       {new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq}
+       inside new wrapper elements]
     }

     val completedBatchesContent = {
-      <h4>
-        Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches})
-      </h4> ++
-        new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq
+      [container markup lost in extraction: the added lines nest the
+       Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches}) heading and
+       {new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq}
+       inside new wrapper elements]
      } activeBatchesContent ++ completedBatchesContent diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index ee2fd45a7e85..19b621f11759 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -97,7 +97,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => val batchDurationMillis = batchDuration.milliseconds // Setup the stream computation - val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + val checkpointDir = Utils.createTempDir(namePrefix = this.getClass.getSimpleName()).toString logDebug(s"Using checkpoint directory $checkpointDir") val ssc = createContextForCheckpointOperation(batchDuration) require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 3b662ec1833a..06c0c2aa97ee 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -39,7 +39,7 @@ class MapWithStateSuite extends SparkFunSuite before { StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } - checkpointDir = Utils.createTempDir("checkpoint") + checkpointDir = Utils.createTempDir(namePrefix = "checkpoint") } after { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 107c3f5dcc08..4fa236bd3966 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult -import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.scheduler.{AllocatedBlocks, _} import org.apache.spark.streaming.util._ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -94,6 +94,27 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } + test("recovery with write ahead logs should remove only allocated blocks from received queue") { + val manualClock = new ManualClock + val batchTime = manualClock.getTimeMillis() + + val tracker1 = createTracker(clock = manualClock) + tracker1.isWriteAheadLogEnabled should be (true) + + val allocatedBlockInfos = generateBlockInfos() + val unallocatedBlockInfos = generateBlockInfos() + val receivedBlockInfos = allocatedBlockInfos ++ unallocatedBlockInfos + receivedBlockInfos.foreach { b => tracker1.writeToLog(BlockAdditionEvent(b)) } + val allocatedBlocks = AllocatedBlocks(Map(streamId -> allocatedBlockInfos)) + tracker1.writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) + tracker1.stop() + + val tracker2 = createTracker(clock = manualClock, recoverFromWriteAheadLog = true) + tracker2.getBlocksOfBatch(batchTime) shouldEqual allocatedBlocks.streamIdToAllocatedBlocks + 
tracker2.getUnallocatedBlocks(streamId) shouldEqual unallocatedBlockInfos + tracker2.stop() + } + test("recovery and cleanup with write ahead logs") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 145c48e5a9a7..fc6218a33f74 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -105,13 +105,13 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { assert(executor.errors.head.eq(exception)) // Verify restarting actually stops and starts the receiver - receiver.restart("restarting", null, 100) - eventually(timeout(50 millis), interval(10 millis)) { + receiver.restart("restarting", null, 600) + eventually(timeout(300 millis), interval(10 millis)) { // receiver will be stopped async assert(receiver.isStopped) assert(receiver.onStopCalled) } - eventually(timeout(1000 millis), interval(100 millis)) { + eventually(timeout(1000 millis), interval(10 millis)) { // receiver will be started async assert(receiver.onStartCalled) assert(executor.isReceiverStarted) diff --git a/tools/pom.xml b/tools/pom.xml index 37427e8da62d..242219e29f50 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml
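
The ReceivedBlockTracker change above is the central behavioural fix in this part of the patch: when batch allocations are replayed from the write-ahead log, insertAllocatedBatch now dequeues only the blocks that belong to the recovered batch instead of clearing every per-stream queue, so blocks that were received but never allocated survive recovery (which is exactly what the new ReceivedBlockTrackerSuite test asserts). The following is a minimal, self-contained sketch of that queue discipline; BlockInfo, RecoveryDemo, and the method names are illustrative stand-ins, not Spark's actual classes.

    import scala.collection.mutable

    // Simplified stand-in for Spark's ReceivedBlockInfo; a stream id and block id are
    // enough to illustrate the recovery bookkeeping.
    case class BlockInfo(streamId: Int, blockId: Long)

    object RecoveryDemo {
      // One received-block queue per input stream, mirroring the tracker's internal map.
      private val unallocatedQueues = mutable.Map.empty[Int, mutable.Queue[BlockInfo]]

      private def queueFor(streamId: Int): mutable.Queue[BlockInfo] =
        unallocatedQueues.getOrElseUpdate(streamId, new mutable.Queue[BlockInfo])

      def addBlock(block: BlockInfo): Unit = queueFor(block.streamId) += block

      // Old behaviour: replaying a recovered batch cleared every queue, silently dropping
      // blocks that had been received but never allocated to any batch.
      def insertAllocatedBatchOld(allocated: Map[Int, Seq[BlockInfo]]): Unit =
        unallocatedQueues.values.foreach(_.clear())

      // New behaviour: remove only the blocks owned by the recovered batch, so still
      // unallocated blocks stay queued and can be allocated to a later batch.
      def insertAllocatedBatchNew(allocated: Map[Int, Seq[BlockInfo]]): Unit =
        allocated.foreach { case (streamId, blocks) =>
          queueFor(streamId).dequeueAll(blocks.toSet)
        }

      def main(args: Array[String]): Unit = {
        val allocated   = Seq(BlockInfo(0, 1L), BlockInfo(0, 2L))
        val unallocated = Seq(BlockInfo(0, 3L))
        (allocated ++ unallocated).foreach(addBlock)

        insertAllocatedBatchNew(Map(0 -> allocated))
        println(queueFor(0).toList) // List(BlockInfo(0,3)) -- only the unallocated block remains
      }
    }

Note that mutable.Queue.dequeueAll takes a predicate, and a Set can be passed directly because Scala sets are functions from element to Boolean; the real patch relies on the same trick with allocatedBlocksInStream.toSet.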
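
Several of the test-suite changes above (ParquetMetastoreSuite, ParquetPartitioningTest, and the existing TestHiveSingleton) adopt the same teardown discipline: suite-specific cleanup is wrapped in try/finally so that super.afterAll() still runs even when dropping tables or deleting directories throws. A small sketch of the pattern, with a hypothetical suite and helper name:

    import org.scalatest.{BeforeAndAfterAll, FunSuite}

    // Hypothetical suite showing the try/finally teardown pattern used in the suites above.
    class ExampleCleanupSuite extends FunSuite with BeforeAndAfterAll {

      private def dropTestTables(): Unit = {
        // Suite-specific cleanup that may throw (e.g. dropping temp tables).
      }

      override protected def afterAll(): Unit = {
        try {
          dropTestTables()
        } finally {
          super.afterAll() // always runs, so shared resources are still torn down
        }
      }

      test("placeholder") {
        assert(true)
      }
    }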