From 86db9b2d7d3c3c8e6c27b05cb36318dba9138e78 Mon Sep 17 00:00:00 2001
From: chetkhatri
Date: Sat, 23 Dec 2017 08:13:34 -0600
Subject: [PATCH 001/100] [SPARK-22833][IMPROVEMENT] in SparkHive Scala Examples

## What changes were proposed in this pull request?

Improvements made to the SparkHive Scala examples:
* Writing a DataFrame / Dataset to a Hive managed table and to a Hive external table using different storage formats.
* Implementation of partitionBy, repartition and coalesce with appropriate examples.

## How was this patch tested?

* Patch has been tested manually and by running ./dev/run-tests.

Author: chetkhatri

Closes #20018 from chetkhatri/scala-sparkhive-examples.
---
 .../examples/sql/hive/SparkHiveExample.scala | 38 +++++++++++++++++--
 1 file changed, 35 insertions(+), 3 deletions(-)

diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala
index e5f75d53edc8..51df5dd8e360 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala
@@ -19,8 +19,7 @@ package org.apache.spark.examples.sql.hive
 // $example on:spark_hive$
 import java.io.File
 
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Row, SaveMode, SparkSession}
 // $example off:spark_hive$
 
 object SparkHiveExample {
@@ -102,8 +101,41 @@ object SparkHiveExample {
     //   |  4| val_4|  4| val_4|
     //   |  5| val_5|  5| val_5|
     // ...
-    // $example off:spark_hive$
 
+    // Create Hive managed table with Parquet
+    sql("CREATE TABLE records(key int, value string) STORED AS PARQUET")
+    // Save DataFrame to Hive managed table as Parquet format
+    val hiveTableDF = sql("SELECT * FROM records")
+    hiveTableDF.write.mode(SaveMode.Overwrite).saveAsTable("database_name.records")
+    // Create External Hive table with Parquet
+    sql("CREATE EXTERNAL TABLE records(key int, value string) " +
+      "STORED AS PARQUET LOCATION '/user/hive/warehouse/'")
+    // to make Hive Parquet format compatible with Spark Parquet format
+    spark.sqlContext.setConf("spark.sql.parquet.writeLegacyFormat", "true")
+
+    // Multiple Parquet files could be created accordingly to volume of data under directory given.
+    val hiveExternalTableLocation = "/user/hive/warehouse/database_name.db/records"
+
+    // Save DataFrame to Hive External table as compatible Parquet format
+    hiveTableDF.write.mode(SaveMode.Overwrite).parquet(hiveExternalTableLocation)
+
+    // Turn on flag for Dynamic Partitioning
+    spark.sqlContext.setConf("hive.exec.dynamic.partition", "true")
+    spark.sqlContext.setConf("hive.exec.dynamic.partition.mode", "nonstrict")
+
+    // You can create partitions in Hive table, so downstream queries run much faster.
+    hiveTableDF.write.mode(SaveMode.Overwrite).partitionBy("key")
+      .parquet(hiveExternalTableLocation)
+
+    // Reduce number of files for each partition by repartition
+    hiveTableDF.repartition($"key").write.mode(SaveMode.Overwrite)
+      .partitionBy("key").parquet(hiveExternalTableLocation)
+
+    // Control the number of files in each partition by coalesce
+    hiveTableDF.coalesce(10).write.mode(SaveMode.Overwrite)
+      .partitionBy("key").parquet(hiveExternalTableLocation)
+    // $example off:spark_hive$
+    
     spark.stop()
   }
 }

From ea2642eb0ebcf02e8fba727a82c140c9f2284725 Mon Sep 17 00:00:00 2001
From: CNRui <13266776177@163.com>
Date: Sat, 23 Dec 2017 08:18:08 -0600
Subject: [PATCH 002/100] [SPARK-20694][EXAMPLES] Update SQLDataSourceExample.scala

## What changes were proposed in this pull request?

Create the table using the right DataFrame: peopleDF -> usersDF.

peopleDF:
+----+-------+
| age|   name|
+----+-------+

usersDF:
+------+--------------+----------------+
|  name|favorite_color|favorite_numbers|
+------+--------------+----------------+

## How was this patch tested?

Manually tested.

Author: CNRui <13266776177@163.com>

Closes #20052 from CNRui/patch-2.
---
 .../apache/spark/examples/sql/SQLDataSourceExample.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
index f9477969a4bb..7d83aacb1154 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
@@ -67,15 +67,15 @@ object SQLDataSourceExample {
     usersDF.write.partitionBy("favorite_color").format("parquet").save("namesPartByColor.parquet")
     // $example off:write_partitioning$
     // $example on:write_partition_and_bucket$
-    peopleDF
+    usersDF
       .write
       .partitionBy("favorite_color")
       .bucketBy(42, "name")
-      .saveAsTable("people_partitioned_bucketed")
+      .saveAsTable("users_partitioned_bucketed")
     // $example off:write_partition_and_bucket$
 
     spark.sql("DROP TABLE IF EXISTS people_bucketed")
-    spark.sql("DROP TABLE IF EXISTS people_partitioned_bucketed")
+    spark.sql("DROP TABLE IF EXISTS users_partitioned_bucketed")
   }
 
   private def runBasicParquetExample(spark: SparkSession): Unit = {

From f6084a88f0fe69111df8a016bc81c9884d3d3402 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Sun, 24 Dec 2017 01:16:12 +0900
Subject: [PATCH 003/100] [HOTFIX] Fix Scala style checks

## What changes were proposed in this pull request?

This PR fixes a style issue that broke the build.

## How was this patch tested?

Manually tested.

Author: hyukjinkwon

Closes #20065 from HyukjinKwon/minor-style.
--- .../org/apache/spark/examples/sql/hive/SparkHiveExample.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala index 51df5dd8e360..b193bd595127 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -135,7 +135,7 @@ object SparkHiveExample { hiveTableDF.coalesce(10).write.mode(SaveMode.Overwrite) .partitionBy("key").parquet(hiveExternalTableLocation) // $example off:spark_hive$ - + spark.stop() } } From aeb45df668a97a2d48cfd4079ed62601390979ba Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 24 Dec 2017 01:18:11 +0900 Subject: [PATCH 004/100] [SPARK-22844][R] Adds date_trunc in R API ## What changes were proposed in this pull request? This PR adds `date_trunc` in R API as below: ```r > df <- createDataFrame(list(list(a = as.POSIXlt("2012-12-13 12:34:00")))) > head(select(df, date_trunc("hour", df$a))) date_trunc(hour, a) 1 2012-12-13 12:00:00 ``` ## How was this patch tested? Unit tests added in `R/pkg/tests/fulltests/test_sparkSQL.R`. Author: hyukjinkwon Closes #20031 from HyukjinKwon/r-datetrunc. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 34 +++++++++++++++++++++++---- R/pkg/R/generics.R | 5 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 3 +++ 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 57838f52eac3..dce64e1e607c 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -230,6 +230,7 @@ exportMethods("%<=>%", "date_add", "date_format", "date_sub", + "date_trunc", "datediff", "dayofmonth", "dayofweek", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 237ef061e807..3a96f94e269f 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -40,10 +40,17 @@ NULL #' #' @param x Column to compute on. In \code{window}, it must be a time Column of #' \code{TimestampType}. -#' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse -#' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string -#' to use to specify the truncation method. For example, "year", "yyyy", "yy" for -#' truncate by year, or "month", "mon", "mm" for truncate by month. +#' @param format The format for the given dates or timestamps in Column \code{x}. See the +#' format used in the following methods: +#' \itemize{ +#' \item \code{to_date} and \code{to_timestamp}: it is the string to use to parse +#' Column \code{x} to DateType or TimestampType. +#' \item \code{trunc}: it is the string to use to specify the truncation method. +#' For example, "year", "yyyy", "yy" for truncate by year, or "month", "mon", +#' "mm" for truncate by month. +#' \item \code{date_trunc}: it is similar with \code{trunc}'s but additionally +#' supports "day", "dd", "second", "minute", "hour", "week" and "quarter". +#' } #' @param ... additional argument(s). #' @name column_datetime_functions #' @rdname column_datetime_functions @@ -3478,3 +3485,22 @@ setMethod("trunc", x@jc, as.character(format)) column(jc) }) + +#' @details +#' \code{date_trunc}: Returns timestamp truncated to the unit specified by the format. 
+#' +#' @rdname column_datetime_functions +#' @aliases date_trunc date_trunc,character,Column-method +#' @export +#' @examples +#' +#' \dontrun{ +#' head(select(df, df$time, date_trunc("hour", df$time), date_trunc("minute", df$time), +#' date_trunc("week", df$time), date_trunc("quarter", df$time)))} +#' @note date_trunc since 2.3.0 +setMethod("date_trunc", + signature(format = "character", x = "Column"), + function(format, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_trunc", format, x@jc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 8fcf269087c7..5ddaa669f920 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1043,6 +1043,11 @@ setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) #' @name NULL setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) +#' @rdname column_datetime_functions +#' @export +#' @name NULL +setGeneric("date_trunc", function(format, x) { standardGeneric("date_trunc") }) + #' @rdname column_datetime_functions #' @export #' @name NULL diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index d87f5d270573..6cc0188dae95 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1418,6 +1418,8 @@ test_that("column functions", { c22 <- not(c) c23 <- trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") + trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm") + c24 <- date_trunc("hour", c) + date_trunc("minute", c) + date_trunc("week", c) + + date_trunc("quarter", c) # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) @@ -1729,6 +1731,7 @@ test_that("date functions on a DataFrame", { expect_gt(collect(select(df2, unix_timestamp()))[1, 1], 0) expect_gt(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) expect_gt(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) + expect_equal(collect(select(df2, month(date_trunc("yyyy", df2$b))))[, 1], c(1, 1)) l3 <- list(list(a = 1000), list(a = -1000)) df3 <- createDataFrame(l3) From 1219d7a4343e837749c56492d382b3c814b97271 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 23 Dec 2017 10:27:14 -0800 Subject: [PATCH 005/100] [SPARK-22889][SPARKR] Set overwrite=T when install SparkR in tests ## What changes were proposed in this pull request? Since all CRAN checks go through the same machine, if there is an older partial download or partial install of Spark left behind the tests fail. This PR overwrites the install files when running tests. This shouldn't affect Jenkins as `SPARK_HOME` is set when running Jenkins tests. ## How was this patch tested? Test manually by running `R CMD check --as-cran` Author: Shivaram Venkataraman Closes #20060 from shivaram/sparkr-overwrite-cran. --- R/pkg/tests/run-all.R | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 63812ba70bb5..94d75188fb94 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -27,7 +27,10 @@ if (.Platform$OS.type == "windows") { # Setup global test environment # Install Spark first to set SPARK_HOME -install.spark() + +# NOTE(shivaram): We set overwrite to handle any old tar.gz files or directories left behind on +# CRAN machines. For Jenkins we should already have SPARK_HOME set. 
+install.spark(overwrite = TRUE) sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") From 0bf1a74a773c79e66a67055298a36af477b9e21a Mon Sep 17 00:00:00 2001 From: sujithjay Date: Sun, 24 Dec 2017 11:14:30 -0800 Subject: [PATCH 006/100] [SPARK-22465][CORE] Add a safety-check to RDD defaultPartitioner ## What changes were proposed in this pull request? In choosing a Partitioner to use for a cogroup-like operation between a number of RDDs, the default behaviour was if some of the RDDs already have a partitioner, we choose the one amongst them with the maximum number of partitions. This behaviour, in some cases, could hit the 2G limit (SPARK-6235). To illustrate one such scenario, consider two RDDs: rDD1: with smaller data and smaller number of partitions, alongwith a Partitioner. rDD2: with much larger data and a larger number of partitions, without a Partitioner. The cogroup of these two RDDs could hit the 2G limit, as a larger amount of data is shuffled into a smaller number of partitions. This PR introduces a safety-check wherein the Partitioner is chosen only if either of the following conditions are met: 1. if the number of partitions of the RDD associated with the Partitioner is greater than or equal to the max number of upstream partitions; or 2. if the number of partitions of the RDD associated with the Partitioner is less than and within a single order of magnitude of the max number of upstream partitions. ## How was this patch tested? Unit tests in PartitioningSuite and PairRDDFunctionsSuite Author: sujithjay Closes #20002 from sujithjay/SPARK-22465. --- .../scala/org/apache/spark/Partitioner.scala | 32 +++++++++++++++++-- .../org/apache/spark/PartitioningSuite.scala | 27 ++++++++++++++++ .../spark/rdd/PairRDDFunctionsSuite.scala | 22 +++++++++++++ 3 files changed, 78 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index debbd8d7c26c..437bbaae1968 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -21,6 +21,7 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.math.log10 import scala.reflect.ClassTag import scala.util.hashing.byteswap32 @@ -42,7 +43,9 @@ object Partitioner { /** * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. * - * If any of the RDDs already has a partitioner, choose that one. + * If any of the RDDs already has a partitioner, and the number of partitions of the + * partitioner is either greater than or is less than and within a single order of + * magnitude of the max number of upstream partitions, choose that one. * * Otherwise, we use a default HashPartitioner. 
For the number of partitions, if * spark.default.parallelism is set, then we'll use the value from SparkContext @@ -57,8 +60,15 @@ object Partitioner { def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { val rdds = (Seq(rdd) ++ others) val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0)) - if (hasPartitioner.nonEmpty) { - hasPartitioner.maxBy(_.partitions.length).partitioner.get + + val hasMaxPartitioner: Option[RDD[_]] = if (hasPartitioner.nonEmpty) { + Some(hasPartitioner.maxBy(_.partitions.length)) + } else { + None + } + + if (isEligiblePartitioner(hasMaxPartitioner, rdds)) { + hasMaxPartitioner.get.partitioner.get } else { if (rdd.context.conf.contains("spark.default.parallelism")) { new HashPartitioner(rdd.context.defaultParallelism) @@ -67,6 +77,22 @@ object Partitioner { } } } + + /** + * Returns true if the number of partitions of the RDD is either greater + * than or is less than and within a single order of magnitude of the + * max number of upstream partitions; + * otherwise, returns false + */ + private def isEligiblePartitioner( + hasMaxPartitioner: Option[RDD[_]], + rdds: Seq[RDD[_]]): Boolean = { + if (hasMaxPartitioner.isEmpty) { + return false + } + val maxPartitions = rdds.map(_.partitions.length).max + log10(maxPartitions) - log10(hasMaxPartitioner.get.getNumPartitions) < 1 + } } /** diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index dfe4c25670ce..155ca17db726 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -259,6 +259,33 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva val partitioner = new RangePartitioner(22, rdd) assert(partitioner.numPartitions === 3) } + + test("defaultPartitioner") { + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) + val rdd2 = sc + .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(10)) + val rdd3 = sc + .parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) + .partitionBy(new HashPartitioner(100)) + val rdd4 = sc + .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(9)) + val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11) + + val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2) + val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3) + val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1) + val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3) + val partitioner5 = Partitioner.defaultPartitioner(rdd4, rdd5) + + assert(partitioner1.numPartitions == rdd1.getNumPartitions) + assert(partitioner2.numPartitions == rdd3.getNumPartitions) + assert(partitioner3.numPartitions == rdd3.getNumPartitions) + assert(partitioner4.numPartitions == rdd3.getNumPartitions) + assert(partitioner5.numPartitions == rdd4.getNumPartitions) + + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 65d35264dc10..a39e0469272f 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -310,6 +310,28 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(joined.size > 0) } + // See SPARK-22465 + test("cogroup between multiple RDD " + + "with an 
order of magnitude difference in number of partitions") { + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 1000) + val rdd2 = sc + .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd1.getNumPartitions) + } + + // See SPARK-22465 + test("cogroup between multiple RDD" + + " with number of partitions similar in order of magnitude") { + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) + val rdd2 = sc + .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd2.getNumPartitions) + } + test("rightOuterJoin") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) From fba03133d1369ce8cfebd6ae68b987a95f1b7ea1 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Sun, 24 Dec 2017 22:57:53 -0800 Subject: [PATCH 007/100] [SPARK-22707][ML] Optimize CrossValidator memory occupation by models in fitting ## What changes were proposed in this pull request? Via some test I found CrossValidator still exists memory issue, it will still occupy `O(n*sizeof(model))` memory for holding models when fitting, if well optimized, it should be `O(parallelism*sizeof(model))` This is because modelFutures will hold the reference to model object after future is complete (we can use `future.value.get.get` to fetch it), and the `Future.sequence` and the `modelFutures` array holds references to each model future. So all model object are keep referenced. So it will still occupy `O(n*sizeof(model))` memory. I fix this by merging the `modelFuture` and `foldMetricFuture` together, and use `atomicInteger` to statistic complete fitting tasks and when all done, trigger `trainingDataset.unpersist`. I ever commented this issue on the old PR [SPARK-19357] https://github.com/apache/spark/pull/16774#pullrequestreview-53674264 unfortunately, at that time I do not realize that the issue still exists, but now I confirm it and create this PR to fix it. ## Discussion I give 3 approaches which we can compare, after discussion I realized none of them is ideal, we have to make a trade-off. **After discussion with jkbradley , choose approach 3** ### Approach 1 ~~The approach proposed by MrBago at~~ https://github.com/apache/spark/pull/19904#discussion_r156751569 ~~This approach resolve the model objects referenced issue, allow the model objects to be GCed in time. **BUT, in some cases, it still do not resolve the O(N) model memory occupation issue**. Let me use an extreme case to describe it:~~ ~~suppose we set `parallelism = 1`, and there're 100 paramMaps. So we have 100 fitting & evaluation tasks. In this approach, because of `parallelism = 1`, the code have to wait 100 fitting tasks complete, **(at this time the memory occupation by models already reach 100 * sizeof(model) )** and then it will unpersist training dataset and then do 100 evaluation tasks.~~ ### Approach 2 ~~This approach is my PR old version code~~ https://github.com/apache/spark/pull/19904/commits/2cc7c28f385009570536690d686f2843485942b2 ~~This approach can make sure at any case, the peak memory occupation by models to be `O(numParallelism * sizeof(model))`, but, it exists an issue that, in some extreme case, the "unpersist training dataset" will be delayed until most of the evaluation tasks complete. 
Suppose the case `parallelism = 1`, and there're 100 fitting & evaluation tasks, each fitting&evaluation task have to be executed one by one, so only after the first 99 fitting&evaluation tasks and the 100th fitting task complete, the "unpersist training dataset" will be triggered.~~ ### Approach 3 After I compared approach 1 and approach 2, I realized that, in the case which parallelism is low but there're many fitting & evaluation tasks, we cannot achieve both of the following two goals: - Make the peak memory occupation by models(driver-side) to be O(parallelism * sizeof(model)) - unpersist training dataset before most of the evaluation tasks started. So I vote for a simpler approach, move the unpersist training dataset to the end (Does this really matters ?) Because the goal 1 is more important, we must make sure the peak memory occupation by models (driver-side) to be O(parallelism * sizeof(model)), otherwise it will bring high risk of OOM. Like following code: ``` val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] //...other minor codes val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric metricformodeltrainedwithparamMap.") metric } (executionContext) } val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) trainingDataset.unpersist() // <------- unpersist at the end validationDataset.unpersist() ``` ## How was this patch tested? N/A Author: WeichenXu Closes #19904 from WeichenXu123/fix_cross_validator_memory_issue. --- .../apache/spark/ml/tuning/CrossValidator.scala | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 1682ca91bf83..0130b3e255f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -147,24 +147,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { + val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => + Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] - if (collectSubModelsParam) { subModels.get(splitIndex)(paramIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") @@ -174,6 +162,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) // Wait for metrics to be calculated before unpersisting validation dataset val foldMetrics = 
foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + trainingDataset.unpersist() validationDataset.unpersist() foldMetrics }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits From 33ae2437ba4e634b510b0f96e914ad1ef4ccafd8 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 25 Dec 2017 01:14:09 -0800 Subject: [PATCH 008/100] [SPARK-22893][SQL] Unified the data type mismatch message ## What changes were proposed in this pull request? We should use `dataType.simpleString` to unified the data type mismatch message: Before: ``` spark-sql> select cast(1 as binary); Error in query: cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7; ``` After: ``` park-sql> select cast(1 as binary); Error in query: cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7; ``` ## How was this patch tested? Exist test. Author: Yuming Wang Closes #20064 from wangyum/SPARK-22893. --- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../aggregate/ApproximatePercentile.scala | 4 +- .../expressions/conditionalExpressions.scala | 3 +- .../sql/catalyst/expressions/generators.scala | 10 +- .../sql/catalyst/expressions/predicates.scala | 2 +- .../expressions/windowExpressions.scala | 8 +- .../native/binaryComparison.sql.out | 48 +-- .../typeCoercion/native/inConversion.sql.out | 280 +++++++++--------- .../native/windowFrameCoercion.sql.out | 8 +- .../sql-tests/results/window.sql.out | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- .../spark/sql/GeneratorFunctionSuite.scala | 4 +- 12 files changed, 189 insertions(+), 186 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5279d4127896..274d8813f16d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -181,7 +181,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure( - s"cannot cast ${child.dataType} to $dataType") + s"cannot cast ${child.dataType.simpleString} to ${dataType.simpleString}") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 7facb9dad9a7..149ac265e6ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -132,7 +132,7 @@ case class ApproximatePercentile( case TimestampType => value.asInstanceOf[Long].toDouble case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType]) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type $other") + throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") } buffer.add(doubleValue) } @@ -157,7 +157,7 @@ case class ApproximatePercentile( case DoubleType => doubleResult case _: DecimalType => doubleResult.map(Decimal(_)) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type $other") + throw new 
UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") } if (result.length == 0) { null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 142dfb02be0a..b444c3a7be92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -40,7 +40,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def checkInputDataTypes(): TypeCheckResult = { if (predicate.dataType != BooleanType) { TypeCheckResult.TypeCheckFailure( - s"type of predicate expression in If should be boolean, not ${predicate.dataType}") + "type of predicate expression in If should be boolean, " + + s"not ${predicate.dataType.simpleString}") } else if (!trueValue.dataType.sameType(falseValue.dataType)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 69af7a250a5a..4f4d49166e88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -155,8 +155,8 @@ case class Stack(children: Seq[Expression]) extends Generator { val j = (i - 1) % numFields if (children(i).dataType != elementSchema.fields(j).dataType) { return TypeCheckResult.TypeCheckFailure( - s"Argument ${j + 1} (${elementSchema.fields(j).dataType}) != " + - s"Argument $i (${children(i).dataType})") + s"Argument ${j + 1} (${elementSchema.fields(j).dataType.simpleString}) != " + + s"Argument $i (${children(i).dataType.simpleString})") } } TypeCheckResult.TypeCheckSuccess @@ -249,7 +249,8 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure( - s"input to function explode should be array or map type, not ${child.dataType}") + "input to function explode should be array or map type, " + + s"not ${child.dataType.simpleString}") } // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) @@ -378,7 +379,8 @@ case class Inline(child: Expression) extends UnaryExpression with CollectionGene TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should be array of struct type, not ${child.dataType}") + s"input to function $prettyName should be array of struct type, " + + s"not ${child.dataType.simpleString}") } override def elementSchema: StructType = child.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f4ee3d10f3f4..b469f5cb7586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -195,7 +195,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } case _ => 
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + - s"${value.dataType} != ${mismatchOpt.get.dataType}") + s"${value.dataType.simpleString} != ${mismatchOpt.get.dataType.simpleString}") } } else { TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 220cc4f885d7..dd13d9a3bba5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -70,9 +70,9 @@ case class WindowSpecDefinition( case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && !isValidFrameType(f.valueBoundary.head.dataType) => TypeCheckFailure( - s"The data type '${orderSpec.head.dataType}' used in the order specification does " + - s"not match the data type '${f.valueBoundary.head.dataType}' which is used in the " + - "range frame.") + s"The data type '${orderSpec.head.dataType.simpleString}' used in the order " + + "specification does not match the data type " + + s"'${f.valueBoundary.head.dataType.simpleString}' which is used in the range frame.") case _ => TypeCheckSuccess } } @@ -251,7 +251,7 @@ case class SpecifiedWindowFrame( TypeCheckFailure(s"Window frame $location bound '$e' is not a literal.") case e: Expression if !frameType.inputType.acceptsType(e.dataType) => TypeCheckFailure( - s"The data type of the $location bound '${e.dataType}' does not match " + + s"The data type of the $location bound '${e.dataType.simpleString}' does not match " + s"the expected data type '${frameType.inputType.simpleString}'.") case _ => TypeCheckSuccess } diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out index fe7bde040707..2914d6015ea8 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out @@ -16,7 +16,7 @@ SELECT cast(1 as binary) = '1' FROM t struct<> -- !query 1 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 2 @@ -25,7 +25,7 @@ SELECT cast(1 as binary) > '2' FROM t struct<> -- !query 2 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 3 @@ -34,7 +34,7 @@ SELECT cast(1 as binary) >= '2' FROM t struct<> -- !query 3 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 4 @@ -43,7 +43,7 @@ SELECT cast(1 as binary) < '2' FROM t struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast 
IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 5 @@ -52,7 +52,7 @@ SELECT cast(1 as binary) <= '2' FROM t struct<> -- !query 5 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 6 @@ -61,7 +61,7 @@ SELECT cast(1 as binary) <> '2' FROM t struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 7 @@ -70,7 +70,7 @@ SELECT cast(1 as binary) = cast(null as string) FROM t struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 8 @@ -79,7 +79,7 @@ SELECT cast(1 as binary) > cast(null as string) FROM t struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 9 @@ -88,7 +88,7 @@ SELECT cast(1 as binary) >= cast(null as string) FROM t struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 10 @@ -97,7 +97,7 @@ SELECT cast(1 as binary) < cast(null as string) FROM t struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 11 @@ -106,7 +106,7 @@ SELECT cast(1 as binary) <= cast(null as string) FROM t struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 12 @@ -115,7 +115,7 @@ SELECT cast(1 as binary) <> cast(null as string) FROM t struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 13 @@ -124,7 +124,7 @@ SELECT '1' = cast(1 as binary) FROM t struct<> -- !query 13 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 13 -- !query 14 @@ -133,7 +133,7 @@ SELECT '2' > cast(1 as binary) FROM t struct<> -- !query 14 output 
org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 13 -- !query 15 @@ -142,7 +142,7 @@ SELECT '2' >= cast(1 as binary) FROM t struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14 -- !query 16 @@ -151,7 +151,7 @@ SELECT '2' < cast(1 as binary) FROM t struct<> -- !query 16 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 13 -- !query 17 @@ -160,7 +160,7 @@ SELECT '2' <= cast(1 as binary) FROM t struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14 -- !query 18 @@ -169,7 +169,7 @@ SELECT '2' <> cast(1 as binary) FROM t struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14 -- !query 19 @@ -178,7 +178,7 @@ SELECT cast(null as string) = cast(1 as binary) FROM t struct<> -- !query 19 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30 -- !query 20 @@ -187,7 +187,7 @@ SELECT cast(null as string) > cast(1 as binary) FROM t struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30 -- !query 21 @@ -196,7 +196,7 @@ SELECT cast(null as string) >= cast(1 as binary) FROM t struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 31 -- !query 22 @@ -205,7 +205,7 @@ SELECT cast(null as string) < cast(1 as binary) FROM t struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30 -- !query 23 @@ -214,7 +214,7 @@ SELECT cast(null as string) <= cast(1 as binary) FROM t struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int 
to binary; line 1 pos 31 -- !query 24 @@ -223,7 +223,7 @@ SELECT cast(null as string) <> cast(1 as binary) FROM t struct<> -- !query 24 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 31 -- !query 25 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out index bf8ddee89b79..875ccc1341ec 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out @@ -80,7 +80,7 @@ SELECT cast(1 as tinyint) in (cast('1' as binary)) FROM t struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ByteType != BinaryType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: tinyint != binary; line 1 pos 26 -- !query 10 @@ -89,7 +89,7 @@ SELECT cast(1 as tinyint) in (cast(1 as boolean)) FROM t struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ByteType != BooleanType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: tinyint != boolean; line 1 pos 26 -- !query 11 @@ -98,7 +98,7 @@ SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ByteType != TimestampType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: tinyint != timestamp; line 1 pos 26 -- !query 12 @@ -107,7 +107,7 @@ SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ByteType != DateType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: tinyint != date; line 1 pos 26 -- !query 13 @@ -180,7 +180,7 @@ SELECT cast(1 as smallint) in (cast('1' as binary)) FROM t struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ShortType != BinaryType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: smallint != binary; line 1 pos 27 -- !query 22 @@ -189,7 +189,7 @@ SELECT cast(1 as smallint) in (cast(1 as boolean)) FROM t struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 
'(CAST(1 AS SMALLINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ShortType != BooleanType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: smallint != boolean; line 1 pos 27 -- !query 23 @@ -198,7 +198,7 @@ SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ShortType != TimestampType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: smallint != timestamp; line 1 pos 27 -- !query 24 @@ -207,7 +207,7 @@ SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 24 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ShortType != DateType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: smallint != date; line 1 pos 27 -- !query 25 @@ -280,7 +280,7 @@ SELECT cast(1 as int) in (cast('1' as binary)) FROM t struct<> -- !query 33 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BinaryType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: int != binary; line 1 pos 22 -- !query 34 @@ -289,7 +289,7 @@ SELECT cast(1 as int) in (cast(1 as boolean)) FROM t struct<> -- !query 34 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BooleanType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: int != boolean; line 1 pos 22 -- !query 35 @@ -298,7 +298,7 @@ SELECT cast(1 as int) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 35 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: IntegerType != TimestampType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: int != timestamp; line 1 pos 22 -- !query 36 @@ -307,7 +307,7 @@ SELECT cast(1 as int) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 36 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: IntegerType != DateType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: int != date; line 1 pos 22 -- !query 37 @@ -380,7 +380,7 @@ SELECT cast(1 as bigint) in (cast('1' as binary)) FROM t struct<> -- !query 45 
output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: LongType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: bigint != binary; line 1 pos 25 -- !query 46 @@ -389,7 +389,7 @@ SELECT cast(1 as bigint) in (cast(1 as boolean)) FROM t struct<> -- !query 46 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: LongType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: bigint != boolean; line 1 pos 25 -- !query 47 @@ -398,7 +398,7 @@ SELECT cast(1 as bigint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 47 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: LongType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: bigint != timestamp; line 1 pos 25 -- !query 48 @@ -407,7 +407,7 @@ SELECT cast(1 as bigint) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 48 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: LongType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: bigint != date; line 1 pos 25 -- !query 49 @@ -480,7 +480,7 @@ SELECT cast(1 as float) in (cast('1' as binary)) FROM t struct<> -- !query 57 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: FloatType != BinaryType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: float != binary; line 1 pos 24 -- !query 58 @@ -489,7 +489,7 @@ SELECT cast(1 as float) in (cast(1 as boolean)) FROM t struct<> -- !query 58 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: FloatType != BooleanType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: float != boolean; line 1 pos 24 -- !query 59 @@ -498,7 +498,7 @@ SELECT cast(1 as float) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 59 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: FloatType != TimestampType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: float != timestamp; line 1 pos 24 -- !query 60 @@ -507,7 +507,7 @@ SELECT cast(1 as float) in (cast('2017-12-11 09:30:00' as 
date)) FROM t struct<> -- !query 60 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: FloatType != DateType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: float != date; line 1 pos 24 -- !query 61 @@ -580,7 +580,7 @@ SELECT cast(1 as double) in (cast('1' as binary)) FROM t struct<> -- !query 69 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: double != binary; line 1 pos 25 -- !query 70 @@ -589,7 +589,7 @@ SELECT cast(1 as double) in (cast(1 as boolean)) FROM t struct<> -- !query 70 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: double != boolean; line 1 pos 25 -- !query 71 @@ -598,7 +598,7 @@ SELECT cast(1 as double) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 71 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DoubleType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: double != timestamp; line 1 pos 25 -- !query 72 @@ -607,7 +607,7 @@ SELECT cast(1 as double) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 72 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DoubleType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: double != date; line 1 pos 25 -- !query 73 @@ -680,7 +680,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast('1' as binary)) FROM t struct<> -- !query 81 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BinaryType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != binary; line 1 pos 33 -- !query 82 @@ -689,7 +689,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as boolean)) FROM t struct<> -- !query 82 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BooleanType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != boolean; line 1 pos 33 -- !query 83 @@ 
-698,7 +698,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00.0' as timestamp)) struct<> -- !query 83 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != TimestampType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != timestamp; line 1 pos 33 -- !query 84 @@ -707,7 +707,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 84 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != DateType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != date; line 1 pos 33 -- !query 85 @@ -780,7 +780,7 @@ SELECT cast(1 as string) in (cast('1' as binary)) FROM t struct<> -- !query 93 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: StringType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: string != binary; line 1 pos 25 -- !query 94 @@ -789,7 +789,7 @@ SELECT cast(1 as string) in (cast(1 as boolean)) FROM t struct<> -- !query 94 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: StringType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: string != boolean; line 1 pos 25 -- !query 95 @@ -814,7 +814,7 @@ SELECT cast('1' as binary) in (cast(1 as tinyint)) FROM t struct<> -- !query 97 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ByteType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: binary != tinyint; line 1 pos 27 -- !query 98 @@ -823,7 +823,7 @@ SELECT cast('1' as binary) in (cast(1 as smallint)) FROM t struct<> -- !query 98 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ShortType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: binary != smallint; line 1 pos 27 -- !query 99 @@ -832,7 +832,7 @@ SELECT cast('1' as binary) in (cast(1 as int)) FROM t struct<> -- !query 99 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != IntegerType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: binary != int; 
line 1 pos 27 -- !query 100 @@ -841,7 +841,7 @@ SELECT cast('1' as binary) in (cast(1 as bigint)) FROM t struct<> -- !query 100 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != LongType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: binary != bigint; line 1 pos 27 -- !query 101 @@ -850,7 +850,7 @@ SELECT cast('1' as binary) in (cast(1 as float)) FROM t struct<> -- !query 101 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != FloatType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: binary != float; line 1 pos 27 -- !query 102 @@ -859,7 +859,7 @@ SELECT cast('1' as binary) in (cast(1 as double)) FROM t struct<> -- !query 102 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DoubleType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: binary != double; line 1 pos 27 -- !query 103 @@ -868,7 +868,7 @@ SELECT cast('1' as binary) in (cast(1 as decimal(10, 0))) FROM t struct<> -- !query 103 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BinaryType != DecimalType(10,0); line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: binary != decimal(10,0); line 1 pos 27 -- !query 104 @@ -877,7 +877,7 @@ SELECT cast('1' as binary) in (cast(1 as string)) FROM t struct<> -- !query 104 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BinaryType != StringType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: binary != string; line 1 pos 27 -- !query 105 @@ -894,7 +894,7 @@ SELECT cast('1' as binary) in (cast(1 as boolean)) FROM t struct<> -- !query 106 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: BinaryType != BooleanType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: binary != boolean; line 1 pos 27 -- !query 107 @@ -903,7 +903,7 @@ SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM struct<> -- !query 107 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BinaryType != TimestampType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: binary != timestamp; line 1 pos 27 -- !query 
108 @@ -912,7 +912,7 @@ SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 108 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DateType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: binary != date; line 1 pos 27 -- !query 109 @@ -921,7 +921,7 @@ SELECT true in (cast(1 as tinyint)) FROM t struct<> -- !query 109 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ByteType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: boolean != tinyint; line 1 pos 12 -- !query 110 @@ -930,7 +930,7 @@ SELECT true in (cast(1 as smallint)) FROM t struct<> -- !query 110 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ShortType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: boolean != smallint; line 1 pos 12 -- !query 111 @@ -939,7 +939,7 @@ SELECT true in (cast(1 as int)) FROM t struct<> -- !query 111 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != IntegerType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: boolean != int; line 1 pos 12 -- !query 112 @@ -948,7 +948,7 @@ SELECT true in (cast(1 as bigint)) FROM t struct<> -- !query 112 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != LongType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: boolean != bigint; line 1 pos 12 -- !query 113 @@ -957,7 +957,7 @@ SELECT true in (cast(1 as float)) FROM t struct<> -- !query 113 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != FloatType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: boolean != float; line 1 pos 12 -- !query 114 @@ -966,7 +966,7 @@ SELECT true in (cast(1 as double)) FROM t struct<> -- !query 114 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DoubleType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: boolean != double; line 1 pos 12 -- !query 115 @@ -975,7 +975,7 @@ SELECT true in (cast(1 as decimal(10, 0))) FROM t struct<> -- !query 115 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BooleanType != DecimalType(10,0); line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS 
DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: boolean != decimal(10,0); line 1 pos 12 -- !query 116 @@ -984,7 +984,7 @@ SELECT true in (cast(1 as string)) FROM t struct<> -- !query 116 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BooleanType != StringType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: boolean != string; line 1 pos 12 -- !query 117 @@ -993,7 +993,7 @@ SELECT true in (cast('1' as binary)) FROM t struct<> -- !query 117 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: BooleanType != BinaryType; line 1 pos 12 +cannot resolve '(true IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: boolean != binary; line 1 pos 12 -- !query 118 @@ -1010,7 +1010,7 @@ SELECT true in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 119 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BooleanType != TimestampType; line 1 pos 12 +cannot resolve '(true IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: boolean != timestamp; line 1 pos 12 -- !query 120 @@ -1019,7 +1019,7 @@ SELECT true in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 120 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DateType; line 1 pos 12 +cannot resolve '(true IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: boolean != date; line 1 pos 12 -- !query 121 @@ -1028,7 +1028,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as tinyint)) FROM t struct<> -- !query 121 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ByteType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != tinyint; line 1 pos 50 -- !query 122 @@ -1037,7 +1037,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as smallint)) FROM struct<> -- !query 122 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ShortType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != smallint; line 1 pos 50 -- !query 123 @@ -1046,7 +1046,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as int)) FROM t struct<> -- !query 123 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != IntegerType; line 1 pos 50 
+cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: timestamp != int; line 1 pos 50 -- !query 124 @@ -1055,7 +1055,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as bigint)) FROM t struct<> -- !query 124 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != LongType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != bigint; line 1 pos 50 -- !query 125 @@ -1064,7 +1064,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as float)) FROM t struct<> -- !query 125 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != FloatType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: timestamp != float; line 1 pos 50 -- !query 126 @@ -1073,7 +1073,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as double)) FROM t struct<> -- !query 126 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: TimestampType != DoubleType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: timestamp != double; line 1 pos 50 -- !query 127 @@ -1082,7 +1082,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as decimal(10, 0))) struct<> -- !query 127 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: TimestampType != DecimalType(10,0); line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: timestamp != decimal(10,0); line 1 pos 50 -- !query 128 @@ -1099,7 +1099,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2' as binary)) FROM struct<> -- !query 129 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BinaryType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: timestamp != binary; line 1 pos 50 -- !query 130 @@ -1108,7 +1108,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as boolean)) FROM t struct<> -- !query 130 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BooleanType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type 
but were: timestamp != boolean; line 1 pos 50 -- !query 131 @@ -1133,7 +1133,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as tinyint)) FROM t struct<> -- !query 133 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ByteType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: date != tinyint; line 1 pos 43 -- !query 134 @@ -1142,7 +1142,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as smallint)) FROM t struct<> -- !query 134 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ShortType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: date != smallint; line 1 pos 43 -- !query 135 @@ -1151,7 +1151,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as int)) FROM t struct<> -- !query 135 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: DateType != IntegerType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: date != int; line 1 pos 43 -- !query 136 @@ -1160,7 +1160,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as bigint)) FROM t struct<> -- !query 136 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: DateType != LongType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: date != bigint; line 1 pos 43 -- !query 137 @@ -1169,7 +1169,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as float)) FROM t struct<> -- !query 137 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: DateType != FloatType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: date != float; line 1 pos 43 -- !query 138 @@ -1178,7 +1178,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as double)) FROM t struct<> -- !query 138 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: DateType != DoubleType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: date != double; line 1 pos 43 -- !query 139 @@ -1187,7 +1187,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as decimal(10, 0))) FROM t struct<> -- !query 139 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must 
be same type but were: DateType != DecimalType(10,0); line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: date != decimal(10,0); line 1 pos 43 -- !query 140 @@ -1204,7 +1204,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2' as binary)) FROM t struct<> -- !query 141 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DateType != BinaryType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: date != binary; line 1 pos 43 -- !query 142 @@ -1213,7 +1213,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as boolean)) FROM t struct<> -- !query 142 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DateType != BooleanType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: date != boolean; line 1 pos 43 -- !query 143 @@ -1302,7 +1302,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('1' as binary)) FROM t struct<> -- !query 153 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ByteType != BinaryType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: tinyint != binary; line 1 pos 26 -- !query 154 @@ -1311,7 +1311,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as boolean)) FROM t struct<> -- !query 154 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ByteType != BooleanType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: tinyint != boolean; line 1 pos 26 -- !query 155 @@ -1320,7 +1320,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00.0' a struct<> -- !query 155 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ByteType != TimestampType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: tinyint != timestamp; line 1 pos 26 -- !query 156 @@ -1329,7 +1329,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00' as struct<> -- !query 156 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ByteType != DateType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type 
mismatch: Arguments must be same type but were: tinyint != date; line 1 pos 26 -- !query 157 @@ -1402,7 +1402,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast('1' as binary)) FROM t struct<> -- !query 165 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ShortType != BinaryType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: smallint != binary; line 1 pos 27 -- !query 166 @@ -1411,7 +1411,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as boolean)) FROM t struct<> -- !query 166 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ShortType != BooleanType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: smallint != boolean; line 1 pos 27 -- !query 167 @@ -1420,7 +1420,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00.0' struct<> -- !query 167 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ShortType != TimestampType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: smallint != timestamp; line 1 pos 27 -- !query 168 @@ -1429,7 +1429,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00' a struct<> -- !query 168 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ShortType != DateType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: smallint != date; line 1 pos 27 -- !query 169 @@ -1502,7 +1502,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast('1' as binary)) FROM t struct<> -- !query 177 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BinaryType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: int != binary; line 1 pos 22 -- !query 178 @@ -1511,7 +1511,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast(1 as boolean)) FROM t struct<> -- !query 178 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BooleanType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: int != boolean; line 1 pos 22 -- !query 179 @@ -1520,7 +1520,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 
09:30:00.0' as timest struct<> -- !query 179 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: IntegerType != TimestampType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: int != timestamp; line 1 pos 22 -- !query 180 @@ -1529,7 +1529,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 09:30:00' as date)) F struct<> -- !query 180 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: IntegerType != DateType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: int != date; line 1 pos 22 -- !query 181 @@ -1602,7 +1602,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast('1' as binary)) FROM t struct<> -- !query 189 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: LongType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: bigint != binary; line 1 pos 25 -- !query 190 @@ -1611,7 +1611,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as boolean)) FROM t struct<> -- !query 190 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: LongType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: bigint != boolean; line 1 pos 25 -- !query 191 @@ -1620,7 +1620,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00.0' as struct<> -- !query 191 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: LongType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: bigint != timestamp; line 1 pos 25 -- !query 192 @@ -1629,7 +1629,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00' as da struct<> -- !query 192 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: LongType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: bigint != date; line 1 pos 25 -- !query 193 @@ -1702,7 +1702,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast('1' as binary)) FROM t struct<> -- !query 201 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 
AS FLOAT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: FloatType != BinaryType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: float != binary; line 1 pos 24 -- !query 202 @@ -1711,7 +1711,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast(1 as boolean)) FROM t struct<> -- !query 202 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: FloatType != BooleanType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: float != boolean; line 1 pos 24 -- !query 203 @@ -1720,7 +1720,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00.0' as ti struct<> -- !query 203 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: FloatType != TimestampType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: float != timestamp; line 1 pos 24 -- !query 204 @@ -1729,7 +1729,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00' as date struct<> -- !query 204 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: FloatType != DateType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: float != date; line 1 pos 24 -- !query 205 @@ -1802,7 +1802,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast('1' as binary)) FROM t struct<> -- !query 213 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: double != binary; line 1 pos 25 -- !query 214 @@ -1811,7 +1811,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast(1 as boolean)) FROM t struct<> -- !query 214 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: double != boolean; line 1 pos 25 -- !query 215 @@ -1820,7 +1820,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00.0' as struct<> -- !query 215 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DoubleType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN 
(CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: double != timestamp; line 1 pos 25 -- !query 216 @@ -1829,7 +1829,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00' as da struct<> -- !query 216 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DoubleType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: double != date; line 1 pos 25 -- !query 217 @@ -1902,7 +1902,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('1' as bina struct<> -- !query 225 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BinaryType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != binary; line 1 pos 33 -- !query 226 @@ -1911,7 +1911,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as boolea struct<> -- !query 226 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BooleanType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != boolean; line 1 pos 33 -- !query 227 @@ -1920,7 +1920,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 struct<> -- !query 227 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != TimestampType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != timestamp; line 1 pos 33 -- !query 228 @@ -1929,7 +1929,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 struct<> -- !query 228 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != DateType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != date; line 1 pos 33 -- !query 229 @@ -2002,7 +2002,7 @@ SELECT cast(1 as string) in (cast(1 as string), cast('1' as binary)) FROM t struct<> -- !query 237 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: StringType != 
BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: string != binary; line 1 pos 25 -- !query 238 @@ -2011,7 +2011,7 @@ SELECT cast(1 as string) in (cast(1 as string), cast(1 as boolean)) FROM t struct<> -- !query 238 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: StringType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: string != boolean; line 1 pos 25 -- !query 239 @@ -2036,7 +2036,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as tinyint)) FROM t struct<> -- !query 241 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ByteType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: binary != tinyint; line 1 pos 27 -- !query 242 @@ -2045,7 +2045,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as smallint)) FROM t struct<> -- !query 242 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ShortType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: binary != smallint; line 1 pos 27 -- !query 243 @@ -2054,7 +2054,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as int)) FROM t struct<> -- !query 243 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != IntegerType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: binary != int; line 1 pos 27 -- !query 244 @@ -2063,7 +2063,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as bigint)) FROM t struct<> -- !query 244 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != LongType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: binary != bigint; line 1 pos 27 -- !query 245 @@ -2072,7 +2072,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as float)) FROM t struct<> -- !query 245 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != FloatType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: binary != float; line 1 pos 27 -- !query 246 @@ -2081,7 +2081,7 @@ SELECT 
cast('1' as binary) in (cast('1' as binary), cast(1 as double)) FROM t struct<> -- !query 246 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DoubleType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: binary != double; line 1 pos 27 -- !query 247 @@ -2090,7 +2090,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as decimal(10, 0))) F struct<> -- !query 247 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BinaryType != DecimalType(10,0); line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: binary != decimal(10,0); line 1 pos 27 -- !query 248 @@ -2099,7 +2099,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as string)) FROM t struct<> -- !query 248 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BinaryType != StringType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: binary != string; line 1 pos 27 -- !query 249 @@ -2116,7 +2116,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as boolean)) FROM t struct<> -- !query 250 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: BinaryType != BooleanType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: binary != boolean; line 1 pos 27 -- !query 251 @@ -2125,7 +2125,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00.0' struct<> -- !query 251 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BinaryType != TimestampType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: binary != timestamp; line 1 pos 27 -- !query 252 @@ -2134,7 +2134,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00' a struct<> -- !query 252 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DateType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: binary != date; line 1 pos 27 -- !query 253 @@ -2143,7 +2143,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as tinyint)) FROM t struct<> -- !query 253 output 
org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ByteType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: boolean != tinyint; line 1 pos 28 -- !query 254 @@ -2152,7 +2152,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as smallint)) FROM struct<> -- !query 254 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ShortType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: boolean != smallint; line 1 pos 28 -- !query 255 @@ -2161,7 +2161,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as int)) FROM t struct<> -- !query 255 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != IntegerType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: boolean != int; line 1 pos 28 -- !query 256 @@ -2170,7 +2170,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as bigint)) FROM t struct<> -- !query 256 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != LongType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: boolean != bigint; line 1 pos 28 -- !query 257 @@ -2179,7 +2179,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as float)) FROM t struct<> -- !query 257 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != FloatType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: boolean != float; line 1 pos 28 -- !query 258 @@ -2188,7 +2188,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as double)) FROM t struct<> -- !query 258 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DoubleType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: boolean != double; line 1 pos 28 -- !query 259 @@ -2197,7 +2197,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as decimal(10, 0))) struct<> -- !query 259 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: 
BooleanType != DecimalType(10,0); line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: boolean != decimal(10,0); line 1 pos 28 -- !query 260 @@ -2206,7 +2206,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as string)) FROM t struct<> -- !query 260 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BooleanType != StringType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: boolean != string; line 1 pos 28 -- !query 261 @@ -2215,7 +2215,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast('1' as binary)) FROM struct<> -- !query 261 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: BooleanType != BinaryType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: boolean != binary; line 1 pos 28 -- !query 262 @@ -2232,7 +2232,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00. struct<> -- !query 263 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BooleanType != TimestampType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: boolean != timestamp; line 1 pos 28 -- !query 264 @@ -2241,7 +2241,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00' struct<> -- !query 264 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DateType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: boolean != date; line 1 pos 28 -- !query 265 @@ -2250,7 +2250,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 265 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ByteType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != tinyint; line 1 pos 50 -- !query 266 @@ -2259,7 +2259,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. 
struct<> -- !query 266 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ShortType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != smallint; line 1 pos 50 -- !query 267 @@ -2268,7 +2268,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 267 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != IntegerType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: timestamp != int; line 1 pos 50 -- !query 268 @@ -2277,7 +2277,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 268 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != LongType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != bigint; line 1 pos 50 -- !query 269 @@ -2286,7 +2286,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 269 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != FloatType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: timestamp != float; line 1 pos 50 -- !query 270 @@ -2295,7 +2295,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 270 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: TimestampType != DoubleType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: timestamp != double; line 1 pos 50 -- !query 271 @@ -2304,7 +2304,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. 
struct<> -- !query 271 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: TimestampType != DecimalType(10,0); line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: timestamp != decimal(10,0); line 1 pos 50 -- !query 272 @@ -2321,7 +2321,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 273 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BinaryType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: timestamp != binary; line 1 pos 50 -- !query 274 @@ -2330,7 +2330,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 274 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BooleanType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: timestamp != boolean; line 1 pos 50 -- !query 275 @@ -2355,7 +2355,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 277 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ByteType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: date != tinyint; line 1 pos 43 -- !query 278 @@ -2364,7 +2364,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 278 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ShortType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: date != smallint; line 1 pos 43 -- !query 279 @@ -2373,7 +2373,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 279 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: DateType != IntegerType; line 1 pos 43 +cannot resolve 
'(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: date != int; line 1 pos 43 -- !query 280 @@ -2382,7 +2382,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 280 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: DateType != LongType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: date != bigint; line 1 pos 43 -- !query 281 @@ -2391,7 +2391,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 281 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: DateType != FloatType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: date != float; line 1 pos 43 -- !query 282 @@ -2400,7 +2400,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 282 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: DateType != DoubleType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: date != double; line 1 pos 43 -- !query 283 @@ -2409,7 +2409,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 283 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: DateType != DecimalType(10,0); line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: date != decimal(10,0); line 1 pos 43 -- !query 284 @@ -2426,7 +2426,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 285 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DateType != BinaryType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: date != binary; line 1 pos 43 -- !query 286 @@ -2435,7 +2435,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 286 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN 
(CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DateType != BooleanType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: date != boolean; line 1 pos 43 -- !query 287 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out index 5dd257ba6a0b..01d83938031f 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out @@ -168,7 +168,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string) DESC RANGE BETWE struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'StringType' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'string' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 -- !query 21 @@ -177,7 +177,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary) DESC RANGE BET struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '(PARTITION BY 1 ORDER BY CAST('1' AS BINARY) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'BinaryType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 21 +cannot resolve '(PARTITION BY 1 ORDER BY CAST('1' AS BINARY) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'binary' used in the order specification does not match the data type 'int' which is used in the range frame.; line 1 pos 21 -- !query 22 @@ -186,7 +186,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean) DESC RANGE BETW struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'BooleanType' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'boolean' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 -- !query 23 @@ -195,7 +195,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00.0' as ti struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '(PARTITION BY 1 ORDER BY CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'TimestampType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 21 +cannot resolve '(PARTITION BY 1 ORDER BY CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to 
data type mismatch: The data type 'timestamp' used in the order specification does not match the data type 'int' which is used in the range frame.; line 1 pos 21 -- !query 24 diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index a52e198eb9a8..133458ae9303 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -61,7 +61,7 @@ ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, struct<> -- !query 3 output org.apache.spark.sql.AnalysisException -cannot resolve 'ROWS BETWEEN CURRENT ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'LongType' does not match the expected data type 'int'.; line 1 pos 41 +cannot resolve 'ROWS BETWEEN CURRENT ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'bigint' does not match the expected data type 'int'.; line 1 pos 41 -- !query 4 @@ -221,7 +221,7 @@ RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val struct<> -- !query 14 output org.apache.spark.sql.AnalysisException -cannot resolve '(PARTITION BY testdata.`cate` ORDER BY current_timestamp() ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'TimestampType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 33 +cannot resolve '(PARTITION BY testdata.`cate` ORDER BY current_timestamp() ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'timestamp' used in the order specification does not match the data type 'int' which is used in the range frame.; line 1 pos 33 -- !query 15 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index bd1e7adefc7a..d535896723bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -660,7 +660,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { df.as[KryoData] }.message - assert(e.contains("cannot cast IntegerType to BinaryType")) + assert(e.contains("cannot cast int to binary")) } test("Java encoder") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 6b98209fd49b..109fcf90a3ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -65,7 +65,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val m3 = intercept[AnalysisException] { df.selectExpr("stack(2, 1, '2.2')") }.getMessage - assert(m3.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (StringType)")) + assert(m3.contains("data type mismatch: Argument 1 (int) != Argument 2 (string)")) // stack on column data val df2 = Seq((2, 1, 2, 3)).toDF("n", "a", "b", "c") @@ -80,7 +80,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val m5 = intercept[AnalysisException] { df3.selectExpr("stack(2, a, b)") }.getMessage - assert(m5.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (DoubleType)")) + 
assert(m5.contains("data type mismatch: Argument 1 (int) != Argument 2 (double)")) } From 12d20dd75b1620da362dbb5345bed58e47ddacb9 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 25 Dec 2017 20:29:10 +0900 Subject: [PATCH 009/100] [SPARK-22874][PYSPARK][SQL][FOLLOW-UP] Modify error messages to show actual versions. ## What changes were proposed in this pull request? This is a follow-up pr of #20054 modifying error messages for both pandas and pyarrow to show actual versions. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #20074 from ueshin/issues/SPARK-22874_fup1. --- python/pyspark/sql/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index fb7d42a35d8f..08c34c6dccc5 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -118,7 +118,8 @@ def require_minimum_pandas_version(): from distutils.version import LooseVersion import pandas if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'): - raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process") + raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; " + "however, your version was %s." % pandas.__version__) def require_minimum_pyarrow_version(): @@ -127,4 +128,5 @@ def require_minimum_pyarrow_version(): from distutils.version import LooseVersion import pyarrow if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'): - raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process") + raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; " + "however, your version was %s." % pyarrow.__version__) From be03d3ad793dfe8f3ec33074d4ae95f5adb86ee4 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 25 Dec 2017 16:17:39 -0800 Subject: [PATCH 010/100] [SPARK-22893][SQL][HOTFIX] Fix a error message of VersionsSuite ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/20064 breaks Jenkins tests because it missed to update one error message for Hive 0.12 and Hive 0.13. This PR fixes that. - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/3924/ - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/3977/ - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7/4226/ - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.6/4260/ ## How was this patch tested? Pass the Jenkins without failure. Author: Dongjoon Hyun Closes #20079 from dongjoon-hyun/SPARK-22893. 
--- .../scala/org/apache/spark/sql/hive/client/VersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9d15dabc8d3f..94473a08dd31 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -773,7 +773,7 @@ class VersionsSuite extends SparkFunSuite with Logging { """.stripMargin ) - val errorMsg = "data type mismatch: cannot cast DecimalType(2,1) to BinaryType" + val errorMsg = "data type mismatch: cannot cast decimal(2,1) to binary" if (isPartitioned) { val insertStmt = s"INSERT OVERWRITE TABLE $tableName partition (ds='a') SELECT 1.3" From 0e6833006d28df426eb132bb8fc82917b8e2aedd Mon Sep 17 00:00:00 2001 From: Yash Sharma Date: Tue, 26 Dec 2017 09:50:39 +0200 Subject: [PATCH 011/100] [SPARK-20168][DSTREAM] Add changes to use kinesis fetches from specific timestamp ## What changes were proposed in this pull request? Kinesis client can resume from a specified timestamp while creating a stream. We should have option to pass a timestamp in config to allow kinesis to resume from the given timestamp. The patch introduces a new `KinesisInitialPositionInStream` that takes the `InitialPositionInStream` with the `timestamp` information that can be used to resume kinesis fetches from the provided timestamp. ## How was this patch tested? Unit Tests cc : budde brkyvz Author: Yash Sharma Closes #18029 from yssharma/ysharma/kcl_resume. --- .../kinesis/KinesisInitialPositions.java | 91 +++++++++++++++++++ .../streaming/KinesisWordCountASL.scala | 5 +- .../kinesis/KinesisInputDStream.scala | 31 +++++-- .../streaming/kinesis/KinesisReceiver.scala | 45 +++++---- .../streaming/kinesis/KinesisUtils.scala | 15 ++- .../JavaKinesisInputDStreamBuilderSuite.java | 47 ++++++++-- .../KinesisInputDStreamBuilderSuite.scala | 68 +++++++++++++- .../kinesis/KinesisStreamSuite.scala | 11 ++- 8 files changed, 264 insertions(+), 49 deletions(-) create mode 100644 external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java b/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java new file mode 100644 index 000000000000..206e1e469903 --- /dev/null +++ b/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.streaming.kinesis; + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + +import java.io.Serializable; +import java.util.Date; + +/** + * A java wrapper for exposing [[InitialPositionInStream]] + * to the corresponding Kinesis readers. + */ +interface KinesisInitialPosition { + InitialPositionInStream getPosition(); +} + +public class KinesisInitialPositions { + public static class Latest implements KinesisInitialPosition, Serializable { + public Latest() {} + + @Override + public InitialPositionInStream getPosition() { + return InitialPositionInStream.LATEST; + } + } + + public static class TrimHorizon implements KinesisInitialPosition, Serializable { + public TrimHorizon() {} + + @Override + public InitialPositionInStream getPosition() { + return InitialPositionInStream.TRIM_HORIZON; + } + } + + public static class AtTimestamp implements KinesisInitialPosition, Serializable { + private Date timestamp; + + public AtTimestamp(Date timestamp) { + this.timestamp = timestamp; + } + + @Override + public InitialPositionInStream getPosition() { + return InitialPositionInStream.AT_TIMESTAMP; + } + + public Date getTimestamp() { + return timestamp; + } + } + + + /** + * Returns instance of [[KinesisInitialPosition]] based on the passed [[InitialPositionInStream]]. + * This method is used in KinesisUtils for translating the InitialPositionInStream + * to InitialPosition. This function would be removed when we deprecate the KinesisUtils. + * + * @return [[InitialPosition]] + */ + public static KinesisInitialPosition fromKinesisInitialPosition( + InitialPositionInStream initialPositionInStream) throws UnsupportedOperationException { + if (initialPositionInStream == InitialPositionInStream.LATEST) { + return new Latest(); + } else if (initialPositionInStream == InitialPositionInStream.TRIM_HORIZON) { + return new TrimHorizon(); + } else { + // InitialPositionInStream.AT_TIMESTAMP is not supported. + // Use InitialPosition.atTimestamp(timestamp) instead. + throw new UnsupportedOperationException( + "Only InitialPositionInStream.LATEST and InitialPositionInStream.TRIM_HORIZON " + + "supported in initialPositionInStream(). 
Please use the initialPosition() from " + + "builder API in KinesisInputDStream for using InitialPositionInStream.AT_TIMESTAMP"); + } + } +} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index cde2c4b04c0c..fcb790e3ea1f 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -24,7 +24,6 @@ import scala.util.Random import com.amazonaws.auth.DefaultAWSCredentialsProviderChain import com.amazonaws.services.kinesis.AmazonKinesisClient -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.PutRecordRequest import org.apache.log4j.{Level, Logger} @@ -33,9 +32,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.streaming.dstream.DStream.toPairDStreamFunctions +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.Latest import org.apache.spark.streaming.kinesis.KinesisInputDStream - /** * Consumes messages from a Amazon Kinesis streams and does wordcount. * @@ -139,7 +138,7 @@ object KinesisWordCountASL extends Logging { .streamName(streamName) .endpointUrl(endpointUrl) .regionName(regionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointAppName(appName) .checkpointInterval(kinesisCheckpointInterval) .storageLevel(StorageLevel.MEMORY_AND_DISK_2) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index f61e398e7bdd..1ffec01df9f0 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -28,6 +28,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.{Duration, StreamingContext, Time} import org.apache.spark.streaming.api.java.JavaStreamingContext import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.Latest import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.ReceivedBlockInfo @@ -36,7 +37,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( val streamName: String, val endpointUrl: String, val regionName: String, - val initialPositionInStream: InitialPositionInStream, + val initialPosition: KinesisInitialPosition, val checkpointAppName: String, val checkpointInterval: Duration, val _storageLevel: StorageLevel, @@ -77,7 +78,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( } override def getReceiver(): Receiver[T] = { - new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, + new KinesisReceiver(streamName, endpointUrl, regionName, initialPosition, checkpointAppName, checkpointInterval, _storageLevel, messageHandler, kinesisCreds, dynamoDBCreds, cloudWatchCreds) } @@ -100,7 +101,7 @@ object KinesisInputDStream { // Params with defaults private var endpointUrl: Option[String] = None private var 
regionName: Option[String] = None - private var initialPositionInStream: Option[InitialPositionInStream] = None + private var initialPosition: Option[KinesisInitialPosition] = None private var checkpointInterval: Option[Duration] = None private var storageLevel: Option[StorageLevel] = None private var kinesisCredsProvider: Option[SparkAWSCredentials] = None @@ -180,16 +181,32 @@ object KinesisInputDStream { this } + /** + * Sets the initial position data is read from in the Kinesis stream. Defaults to + * [[KinesisInitialPositions.Latest]] if no custom value is specified. + * + * @param initialPosition [[KinesisInitialPosition]] value specifying where Spark Streaming + * will start reading records in the Kinesis stream from + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def initialPosition(initialPosition: KinesisInitialPosition): Builder = { + this.initialPosition = Option(initialPosition) + this + } + /** * Sets the initial position data is read from in the Kinesis stream. Defaults to * [[InitialPositionInStream.LATEST]] if no custom value is specified. + * This function would be removed when we deprecate the KinesisUtils. * * @param initialPosition InitialPositionInStream value specifying where Spark Streaming * will start reading records in the Kinesis stream from * @return Reference to this [[KinesisInputDStream.Builder]] */ + @deprecated("use initialPosition(initialPosition: KinesisInitialPosition)", "2.3.0") def initialPositionInStream(initialPosition: InitialPositionInStream): Builder = { - initialPositionInStream = Option(initialPosition) + this.initialPosition = Option( + KinesisInitialPositions.fromKinesisInitialPosition(initialPosition)) this } @@ -266,7 +283,7 @@ object KinesisInputDStream { getRequiredParam(streamName, "streamName"), endpointUrl.getOrElse(DEFAULT_KINESIS_ENDPOINT_URL), regionName.getOrElse(DEFAULT_KINESIS_REGION_NAME), - initialPositionInStream.getOrElse(DEFAULT_INITIAL_POSITION_IN_STREAM), + initialPosition.getOrElse(DEFAULT_INITIAL_POSITION), getRequiredParam(checkpointAppName, "checkpointAppName"), checkpointInterval.getOrElse(ssc.graph.batchDuration), storageLevel.getOrElse(DEFAULT_STORAGE_LEVEL), @@ -293,7 +310,6 @@ object KinesisInputDStream { * Creates a [[KinesisInputDStream.Builder]] for constructing [[KinesisInputDStream]] instances. 
* * @since 2.2.0 - * * @return [[KinesisInputDStream.Builder]] instance */ def builder: Builder = new Builder @@ -309,7 +325,6 @@ object KinesisInputDStream { private[kinesis] val DEFAULT_KINESIS_ENDPOINT_URL: String = "https://kinesis.us-east-1.amazonaws.com" private[kinesis] val DEFAULT_KINESIS_REGION_NAME: String = "us-east-1" - private[kinesis] val DEFAULT_INITIAL_POSITION_IN_STREAM: InitialPositionInStream = - InitialPositionInStream.LATEST + private[kinesis] val DEFAULT_INITIAL_POSITION: KinesisInitialPosition = new Latest() private[kinesis] val DEFAULT_STORAGE_LEVEL: StorageLevel = StorageLevel.MEMORY_AND_DISK_2 } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 1026d0fcb59b..fa0de6298a5f 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -24,12 +24,13 @@ import scala.collection.mutable import scala.util.control.NonFatal import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory} -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{KinesisClientLibConfiguration, Worker} import com.amazonaws.services.kinesis.model.Record import org.apache.spark.internal.Logging import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.AtTimestamp import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.Utils @@ -56,12 +57,13 @@ import org.apache.spark.util.Utils * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Region name used by the Kinesis Client Library for * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). + * @param initialPosition Instance of [[KinesisInitialPosition]] + * In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * ([[KinesisInitialPositions.TrimHorizon]]) or + * the tip of the stream ([[KinesisInitialPositions.Latest]]). * @param checkpointAppName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams * by the Kinesis Client Library. If you change the App name or Stream name, * the KCL will throw errors. 
This usually requires deleting the backing @@ -83,7 +85,7 @@ private[kinesis] class KinesisReceiver[T]( val streamName: String, endpointUrl: String, regionName: String, - initialPositionInStream: InitialPositionInStream, + initialPosition: KinesisInitialPosition, checkpointAppName: String, checkpointInterval: Duration, storageLevel: StorageLevel, @@ -148,18 +150,29 @@ private[kinesis] class KinesisReceiver[T]( kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) val kinesisProvider = kinesisCreds.provider - val kinesisClientLibConfiguration = new KinesisClientLibConfiguration( - checkpointAppName, - streamName, - kinesisProvider, - dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider), - cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), - workerId) + + val kinesisClientLibConfiguration = { + val baseClientLibConfiguration = new KinesisClientLibConfiguration( + checkpointAppName, + streamName, + kinesisProvider, + dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider), + cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), + workerId) .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPositionInStream) + .withInitialPositionInStream(initialPosition.getPosition) .withTaskBackoffTimeMillis(500) .withRegionName(regionName) + // Update the Kinesis client lib config with timestamp + // if InitialPositionInStream.AT_TIMESTAMP is passed + initialPosition match { + case ts: AtTimestamp => + baseClientLibConfiguration.withTimestampAtInitialPositionInStream(ts.getTimestamp) + case _ => baseClientLibConfiguration + } + } + /* * RecordProcessorFactory creates impls of IRecordProcessor. * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 1298463bfba1..2500460bd330 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -73,7 +73,8 @@ object KinesisUtils { // Setting scope to override receiver stream's scope of "receiver stream" ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), + kinesisAppName, checkpointInterval, storageLevel, cleanedHandler, DefaultCredentials, None, None) } } @@ -129,7 +130,8 @@ object KinesisUtils { awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), + kinesisAppName, checkpointInterval, storageLevel, cleanedHandler, kinesisCredsProvider, None, None) } } @@ -198,7 +200,8 @@ object KinesisUtils { awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey)) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), + kinesisAppName, checkpointInterval, storageLevel, cleanedHandler, 
kinesisCredsProvider, None, None) } } @@ -243,7 +246,8 @@ object KinesisUtils { // Setting scope to override receiver stream's scope of "receiver stream" ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), + kinesisAppName, checkpointInterval, storageLevel, KinesisInputDStream.defaultMessageHandler, DefaultCredentials, None, None) } } @@ -293,7 +297,8 @@ object KinesisUtils { awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), + kinesisAppName, checkpointInterval, storageLevel, KinesisInputDStream.defaultMessageHandler, kinesisCredsProvider, None, None) } } diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java index be6d549b1e42..03becd73d1a0 100644 --- a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java +++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java @@ -17,14 +17,13 @@ package org.apache.spark.streaming.kinesis; -import org.junit.Test; - import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; - +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.TrimHorizon; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.Duration; -import org.apache.spark.streaming.Seconds; import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.Seconds; +import org.junit.Test; public class JavaKinesisInputDStreamBuilderSuite extends LocalJavaStreamingContext { /** @@ -35,7 +34,41 @@ public void testJavaKinesisDStreamBuilder() { String streamName = "a-very-nice-stream-name"; String endpointUrl = "https://kinesis.us-west-2.amazonaws.com"; String region = "us-west-2"; - InitialPositionInStream initialPosition = InitialPositionInStream.TRIM_HORIZON; + KinesisInitialPosition initialPosition = new TrimHorizon(); + String appName = "a-very-nice-kinesis-app"; + Duration checkpointInterval = Seconds.apply(30); + StorageLevel storageLevel = StorageLevel.MEMORY_ONLY(); + + KinesisInputDStream kinesisDStream = KinesisInputDStream.builder() + .streamingContext(ssc) + .streamName(streamName) + .endpointUrl(endpointUrl) + .regionName(region) + .initialPosition(initialPosition) + .checkpointAppName(appName) + .checkpointInterval(checkpointInterval) + .storageLevel(storageLevel) + .build(); + assert(kinesisDStream.streamName() == streamName); + assert(kinesisDStream.endpointUrl() == endpointUrl); + assert(kinesisDStream.regionName() == region); + assert(kinesisDStream.initialPosition().getPosition() == initialPosition.getPosition()); + assert(kinesisDStream.checkpointAppName() == appName); + assert(kinesisDStream.checkpointInterval() == checkpointInterval); + assert(kinesisDStream._storageLevel() == storageLevel); + ssc.stop(); + } + + /** + * Test to ensure that the old API for InitialPositionInStream + * is supported in 
KinesisDStream.Builder. + * This test would be removed when we deprecate the KinesisUtils. + */ + @Test + public void testJavaKinesisDStreamBuilderOldApi() { + String streamName = "a-very-nice-stream-name"; + String endpointUrl = "https://kinesis.us-west-2.amazonaws.com"; + String region = "us-west-2"; String appName = "a-very-nice-kinesis-app"; Duration checkpointInterval = Seconds.apply(30); StorageLevel storageLevel = StorageLevel.MEMORY_ONLY(); @@ -45,7 +78,7 @@ public void testJavaKinesisDStreamBuilder() { .streamName(streamName) .endpointUrl(endpointUrl) .regionName(region) - .initialPositionInStream(initialPosition) + .initialPositionInStream(InitialPositionInStream.LATEST) .checkpointAppName(appName) .checkpointInterval(checkpointInterval) .storageLevel(storageLevel) @@ -53,7 +86,7 @@ public void testJavaKinesisDStreamBuilder() { assert(kinesisDStream.streamName() == streamName); assert(kinesisDStream.endpointUrl() == endpointUrl); assert(kinesisDStream.regionName() == region); - assert(kinesisDStream.initialPositionInStream() == initialPosition); + assert(kinesisDStream.initialPosition().getPosition() == InitialPositionInStream.LATEST); assert(kinesisDStream.checkpointAppName() == appName); assert(kinesisDStream.checkpointInterval() == checkpointInterval); assert(kinesisDStream._storageLevel() == storageLevel); diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala index afa1a7f8ca66..e0e26847aa0e 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.streaming.kinesis +import java.util.Calendar + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import org.scalatest.BeforeAndAfterEach import org.scalatest.mockito.MockitoSugar import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{Duration, Seconds, StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.{AtTimestamp, TrimHorizon} class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterEach with MockitoSugar { @@ -69,7 +72,7 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE val dstream = builder.build() assert(dstream.endpointUrl == DEFAULT_KINESIS_ENDPOINT_URL) assert(dstream.regionName == DEFAULT_KINESIS_REGION_NAME) - assert(dstream.initialPositionInStream == DEFAULT_INITIAL_POSITION_IN_STREAM) + assert(dstream.initialPosition == DEFAULT_INITIAL_POSITION) assert(dstream.checkpointInterval == batchDuration) assert(dstream._storageLevel == DEFAULT_STORAGE_LEVEL) assert(dstream.kinesisCreds == DefaultCredentials) @@ -80,7 +83,7 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE test("should propagate custom non-auth values to KinesisInputDStream") { val customEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" val customRegion = "us-west-2" - val customInitialPosition = InitialPositionInStream.TRIM_HORIZON + val customInitialPosition = new TrimHorizon() val customAppName = "a-very-nice-kinesis-app" val customCheckpointInterval = Seconds(30) val 
customStorageLevel = StorageLevel.MEMORY_ONLY @@ -91,7 +94,7 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE val dstream = builder .endpointUrl(customEndpointUrl) .regionName(customRegion) - .initialPositionInStream(customInitialPosition) + .initialPosition(customInitialPosition) .checkpointAppName(customAppName) .checkpointInterval(customCheckpointInterval) .storageLevel(customStorageLevel) @@ -101,12 +104,67 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE .build() assert(dstream.endpointUrl == customEndpointUrl) assert(dstream.regionName == customRegion) - assert(dstream.initialPositionInStream == customInitialPosition) + assert(dstream.initialPosition == customInitialPosition) assert(dstream.checkpointAppName == customAppName) assert(dstream.checkpointInterval == customCheckpointInterval) assert(dstream._storageLevel == customStorageLevel) assert(dstream.kinesisCreds == customKinesisCreds) assert(dstream.dynamoDBCreds == Option(customDynamoDBCreds)) assert(dstream.cloudWatchCreds == Option(customCloudWatchCreds)) + + // Testing with AtTimestamp + val cal = Calendar.getInstance() + cal.add(Calendar.DATE, -1) + val timestamp = cal.getTime() + val initialPositionAtTimestamp = new AtTimestamp(timestamp) + + val dstreamAtTimestamp = builder + .endpointUrl(customEndpointUrl) + .regionName(customRegion) + .initialPosition(initialPositionAtTimestamp) + .checkpointAppName(customAppName) + .checkpointInterval(customCheckpointInterval) + .storageLevel(customStorageLevel) + .kinesisCredentials(customKinesisCreds) + .dynamoDBCredentials(customDynamoDBCreds) + .cloudWatchCredentials(customCloudWatchCreds) + .build() + assert(dstreamAtTimestamp.endpointUrl == customEndpointUrl) + assert(dstreamAtTimestamp.regionName == customRegion) + assert(dstreamAtTimestamp.initialPosition.getPosition + == initialPositionAtTimestamp.getPosition) + assert( + dstreamAtTimestamp.initialPosition.asInstanceOf[AtTimestamp].getTimestamp.equals(timestamp)) + assert(dstreamAtTimestamp.checkpointAppName == customAppName) + assert(dstreamAtTimestamp.checkpointInterval == customCheckpointInterval) + assert(dstreamAtTimestamp._storageLevel == customStorageLevel) + assert(dstreamAtTimestamp.kinesisCreds == customKinesisCreds) + assert(dstreamAtTimestamp.dynamoDBCreds == Option(customDynamoDBCreds)) + assert(dstreamAtTimestamp.cloudWatchCreds == Option(customCloudWatchCreds)) + } + + test("old Api should throw UnsupportedOperationExceptionexception with AT_TIMESTAMP") { + val streamName: String = "a-very-nice-stream-name" + val endpointUrl: String = "https://kinesis.us-west-2.amazonaws.com" + val region: String = "us-west-2" + val appName: String = "a-very-nice-kinesis-app" + val checkpointInterval: Duration = Seconds.apply(30) + val storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + + // This should not build. + // InitialPositionInStream.AT_TIMESTAMP is not supported in old Api. + // The builder Api in KinesisInputDStream should be used. 
+ intercept[UnsupportedOperationException] { + val kinesisDStream: KinesisInputDStream[Array[Byte]] = KinesisInputDStream.builder + .streamingContext(ssc) + .streamName(streamName) + .endpointUrl(endpointUrl) + .regionName(region) + .initialPositionInStream(InitialPositionInStream.AT_TIMESTAMP) + .checkpointAppName(appName) + .checkpointInterval(checkpointInterval) + .storageLevel(storageLevel) + .build + } } } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 7e5bda923f63..a7a68eba910b 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.Latest import org.apache.spark.streaming.kinesis.KinesisReadConfigurations._ import org.apache.spark.streaming.kinesis.KinesisTestUtils._ import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult @@ -178,7 +179,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun .streamName(testUtils.streamName) .endpointUrl(testUtils.endpointUrl) .regionName(testUtils.regionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointInterval(Seconds(10)) .storageLevel(StorageLevel.MEMORY_ONLY) .build() @@ -209,7 +210,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun .streamName(testUtils.streamName) .endpointUrl(testUtils.endpointUrl) .regionName(testUtils.regionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointInterval(Seconds(10)) .storageLevel(StorageLevel.MEMORY_ONLY) .buildWithMessageHandler(addFive(_)) @@ -245,7 +246,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun .streamName("dummyStream") .endpointUrl(dummyEndpointUrl) .regionName(dummyRegionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointInterval(Seconds(10)) .storageLevel(StorageLevel.MEMORY_ONLY) .build() @@ -293,7 +294,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun .streamName(localTestUtils.streamName) .endpointUrl(localTestUtils.endpointUrl) .regionName(localTestUtils.regionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointInterval(Seconds(10)) .storageLevel(StorageLevel.MEMORY_ONLY) .build() @@ -369,7 +370,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun .streamName(testUtils.streamName) .endpointUrl(testUtils.endpointUrl) .regionName(testUtils.regionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointInterval(Seconds(10)) .storageLevel(StorageLevel.MEMORY_ONLY) .build() From eb386be1ed383323da6e757f63f3b8a7ced38cc4 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 26 Dec 2017 21:37:25 +0900 Subject: [PATCH 012/100] [SPARK-21552][SQL] Add DecimalType support to ArrowWriter. ## What changes were proposed in this pull request? 
Decimal type is not yet supported in `ArrowWriter`. This is adding the decimal type support. ## How was this patch tested? Added a test to `ArrowConvertersSuite`. Author: Takuya UESHIN Closes #18754 from ueshin/issues/SPARK-21552. --- python/pyspark/sql/tests.py | 61 +++++++++++------ python/pyspark/sql/types.py | 2 +- .../sql/execution/arrow/ArrowWriter.scala | 21 ++++++ .../arrow/ArrowConvertersSuite.scala | 67 ++++++++++++++++++- .../execution/arrow/ArrowWriterSuite.scala | 2 + 5 files changed, 131 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b977160af566..b811a0f8a31d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3142,6 +3142,7 @@ class ArrowTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): from datetime import datetime + from decimal import Decimal ReusedSQLTestCase.setUpClass() # Synchronize default timezone between Python and Java @@ -3158,11 +3159,15 @@ def setUpClass(cls): StructField("3_long_t", LongType(), True), StructField("4_float_t", FloatType(), True), StructField("5_double_t", DoubleType(), True), - StructField("6_date_t", DateType(), True), - StructField("7_timestamp_t", TimestampType(), True)]) - cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), - (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), - (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + StructField("6_decimal_t", DecimalType(38, 18), True), + StructField("7_date_t", DateType(), True), + StructField("8_timestamp_t", TimestampType(), True)]) + cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), + datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), + datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), + datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] @classmethod def tearDownClass(cls): @@ -3190,10 +3195,11 @@ def create_pandas_data_frame(self): return pd.DataFrame(data=data_dict) def test_unsupported_datatype(self): - schema = StructType([StructField("decimal", DecimalType(), True)]) + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): - self.assertRaises(Exception, lambda: df.toPandas()) + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.toPandas() def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + @@ -3293,7 +3299,7 @@ def test_createDataFrame_respect_session_timezone(self): self.assertNotEqual(result_ny, result_la) # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v + result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v for k, v in row.asDict().items()}) for row in result_la] self.assertEqual(result_ny, result_la_corrected) @@ -3317,11 +3323,11 @@ def test_createDataFrame_with_incorrect_schema(self): def test_createDataFrame_with_names(self): pdf = self.create_pandas_data_frame() # Test that schema as a list of column names gets applied - df = self.spark.createDataFrame(pdf, schema=list('abcdefg')) - self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + df = self.spark.createDataFrame(pdf, schema=list('abcdefgh')) + 
self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) # Test that schema as tuple of column names gets applied - df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg')) - self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh')) + self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) def test_createDataFrame_column_name_encoding(self): import pandas as pd @@ -3344,7 +3350,7 @@ def test_createDataFrame_does_not_modify_input(self): # Some series get converted for Spark to consume, this makes sure input is unchanged pdf = self.create_pandas_data_frame() # Use a nanosecond value to make sure it is not truncated - pdf.ix[0, '7_timestamp_t'] = pd.Timestamp(1) + pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1) # Integers with nulls will get NaNs filled with 0 and will be casted pdf.ix[1, '2_int_t'] = None pdf_copy = pdf.copy(deep=True) @@ -3514,6 +3520,7 @@ def test_vectorized_udf_basic(self): col('id').alias('long'), col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), + col('id').cast('decimal').alias('decimal'), col('id').cast('boolean').alias('bool')) f = lambda x: x str_f = pandas_udf(f, StringType()) @@ -3521,10 +3528,12 @@ def test_vectorized_udf_basic(self): long_f = pandas_udf(f, LongType()) float_f = pandas_udf(f, FloatType()) double_f = pandas_udf(f, DoubleType()) + decimal_f = pandas_udf(f, DecimalType()) bool_f = pandas_udf(f, BooleanType()) res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), - double_f(col('double')), bool_f(col('bool'))) + double_f(col('double')), decimal_f('decimal'), + bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_boolean(self): @@ -3590,6 +3599,16 @@ def test_vectorized_udf_null_double(self): res = df.select(double_f(col('double'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_decimal(self): + from decimal import Decimal + from pyspark.sql.functions import pandas_udf, col + data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)] + schema = StructType().add("decimal", DecimalType(38, 18)) + df = self.spark.createDataFrame(data, schema) + decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18)) + res = df.select(decimal_f(col('decimal'))) + self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_string(self): from pyspark.sql.functions import pandas_udf, col data = [("foo",), (None,), ("bar",), ("bar",)] @@ -3607,6 +3626,7 @@ def test_vectorized_udf_datatype_string(self): col('id').alias('long'), col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), + col('id').cast('decimal').alias('decimal'), col('id').cast('boolean').alias('bool')) f = lambda x: x str_f = pandas_udf(f, 'string') @@ -3614,10 +3634,12 @@ def test_vectorized_udf_datatype_string(self): long_f = pandas_udf(f, 'long') float_f = pandas_udf(f, 'float') double_f = pandas_udf(f, 'double') + decimal_f = pandas_udf(f, 'decimal(38, 18)') bool_f = pandas_udf(f, 'boolean') res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), - double_f(col('double')), bool_f(col('bool'))) + double_f(col('double')), decimal_f('decimal'), + bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_complex(self): @@ -3713,12 +3735,12 @@ def test_vectorized_udf_varargs(self): def test_vectorized_udf_unsupported_types(self): from pyspark.sql.functions 
import pandas_udf, col - schema = StructType([StructField("dt", DecimalType(), True)]) + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) - f = pandas_udf(lambda x: x, DecimalType()) + f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.select(f(col('dt'))).collect() + df.select(f(col('map'))).collect() def test_vectorized_udf_null_date(self): from pyspark.sql.functions import pandas_udf, col @@ -4012,7 +4034,8 @@ def test_wrong_args(self): def test_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType schema = StructType( - [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)]) + [StructField("id", LongType(), True), + StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(1, None,)], schema=schema) f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) with QuietTest(self.sc): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 063264a89379..02b245713b6a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1617,7 +1617,7 @@ def to_arrow_type(dt): elif type(dt) == DoubleType: arrow_type = pa.float64() elif type(dt) == DecimalType: - arrow_type = pa.decimal(dt.precision, dt.scale) + arrow_type = pa.decimal128(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() elif type(dt) == DateType: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 0258056d9de4..22b63513548f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -53,6 +53,8 @@ object ArrowWriter { case (LongType, vector: BigIntVector) => new LongWriter(vector) case (FloatType, vector: Float4Vector) => new FloatWriter(vector) case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) + case (DecimalType.Fixed(precision, scale), vector: DecimalVector) => + new DecimalWriter(vector, precision, scale) case (StringType, vector: VarCharVector) => new StringWriter(vector) case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) case (DateType, vector: DateDayVector) => new DateWriter(vector) @@ -214,6 +216,25 @@ private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFi } } +private[arrow] class DecimalWriter( + val valueVector: DecimalVector, + precision: Int, + scale: Int) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val decimal = input.getDecimal(ordinal, precision, scale) + if (decimal.changePrecision(precision, scale)) { + valueVector.setSafe(count, decimal.toJavaBigDecimal) + } else { + setNull() + } + } +} + private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter { override def setNull(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index fd5a3df6abc6..261df06100ae 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -304,6 +304,70 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "floating_point-double_precision.json") } + test("decimal conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_d", + | "type" : { + | "name" : "decimal", + | "precision" : 38, + | "scale" : 18 + | }, + | "nullable" : true, + | "children" : [ ] + | }, { + | "name" : "b_d", + | "type" : { + | "name" : "decimal", + | "precision" : 38, + | "scale" : 18 + | }, + | "nullable" : true, + | "children" : [ ] + | } ] + | }, + | "batches" : [ { + | "count" : 7, + | "columns" : [ { + | "name" : "a_d", + | "count" : 7, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ + | "1000000000000000000", + | "2000000000000000000", + | "10000000000000000", + | "200000000000000000000", + | "100000000000000", + | "20000000000000000000000", + | "30000000000000000000" ] + | }, { + | "name" : "b_d", + | "count" : 7, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1, 0 ], + | "DATA" : [ + | "1100000000000000000", + | "0", + | "0", + | "2200000000000000000", + | "0", + | "3300000000000000000", + | "0" ] + | } ] + | } ] + |} + """.stripMargin + + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0, 30.0).map(Decimal(_)) + val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3)), + Some(Decimal("123456789012345678901234567890"))) + val df = a_d.zip(b_d).toDF("a_d", "b_d") + + collectAndValidate(df, json, "decimalData.json") + } + test("index conversion") { val data = List[Int](1, 2, 3, 4, 5, 6) val json = @@ -1153,7 +1217,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { decimalData.toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index a71e30aa3ca9..508c116aae92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -49,6 +49,7 @@ class ArrowWriterSuite extends SparkFunSuite { case LongType => reader.getLong(rowId) case FloatType => reader.getFloat(rowId) case DoubleType => reader.getDouble(rowId) + case DecimalType.Fixed(precision, scale) => reader.getDecimal(rowId, precision, scale) case StringType => reader.getUTF8String(rowId) case BinaryType => reader.getBinary(rowId) case DateType => reader.getInt(rowId) @@ -66,6 +67,7 @@ class ArrowWriterSuite extends SparkFunSuite { check(LongType, Seq(1L, 2L, null, 4L)) check(FloatType, Seq(1.0f, 2.0f, null, 4.0f)) check(DoubleType, Seq(1.0d, 2.0d, 
null, 4.0d)) + check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4))) check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) check(DateType, Seq(0, 1, 2, null, 4)) From ff48b1b338241039a7189e7a3c04333b1256fdb3 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 26 Dec 2017 06:39:40 -0800 Subject: [PATCH 013/100] [SPARK-22901][PYTHON] Add deterministic flag to pyspark UDF ## What changes were proposed in this pull request? In SPARK-20586 the flag `deterministic` was added to Scala UDF, but it is not available for python UDF. This flag is useful for cases when the UDF's code can return different result with the same input. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. This can lead to unexpected behavior. This PR adds the deterministic flag, via the `asNondeterministic` method, to let the user mark the function as non-deterministic and therefore avoid the optimizations which might lead to strange behaviors. ## How was this patch tested? Manual tests: ``` >>> from pyspark.sql.functions import * >>> from pyspark.sql.types import * >>> df_br = spark.createDataFrame([{'name': 'hello'}]) >>> import random >>> udf_random_col = udf(lambda: int(100*random.random()), IntegerType()).asNondeterministic() >>> df_br = df_br.withColumn('RAND', udf_random_col()) >>> random.seed(1234) >>> udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) >>> df_br.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).show() +-----+----+-------------+ | name|RAND|RAND_PLUS_TEN| +-----+----+-------------+ |hello| 3| 13| +-----+----+-------------+ ``` Author: Marco Gaido Author: Marco Gaido Closes #19929 from mgaido91/SPARK-22629. --- .../org/apache/spark/api/python/PythonRunner.scala | 7 +++++++ python/pyspark/sql/functions.py | 11 ++++++++--- python/pyspark/sql/tests.py | 9 +++++++++ python/pyspark/sql/udf.py | 13 ++++++++++++- .../org/apache/spark/sql/UDFRegistration.scala | 5 +++-- .../spark/sql/execution/python/PythonUDF.scala | 5 ++++- .../python/UserDefinedPythonFunction.scala | 5 +++-- .../execution/python/BatchEvalPythonExecSuite.scala | 3 ++- 8 files changed, 48 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 93d508c28ebb..1ec0e717fac2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -39,6 +39,13 @@ private[spark] object PythonEvalType { val SQL_PANDAS_SCALAR_UDF = 200 val SQL_PANDAS_GROUP_MAP_UDF = 201 + + def toString(pythonEvalType: Int): String = pythonEvalType match { + case NON_UDF => "NON_UDF" + case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" + case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF" + case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" + } } /** diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ddd8df3b15bf..66ee033bab99 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2093,9 +2093,14 @@ class PandasUDFType(object): def udf(f=None, returnType=StringType()): """Creates a user defined function (UDF). - .. note:: The user-defined functions must be deterministic. 
Due to optimization, - duplicate invocations may be eliminated or the function may even be invoked more times than - it is present in the query. + .. note:: The user-defined functions are considered deterministic by default. Due to + optimization, duplicate invocations may be eliminated or the function may even be invoked + more times than it is present in the query. If your function is not deterministic, call + `asNondeterministic` on the user defined function. E.g.: + + >>> from pyspark.sql.types import IntegerType + >>> import random + >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() .. note:: The user-defined functions do not support conditional expressions or short curcuiting in boolean expressions and it ends up with being executed all internally. If the functions diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b811a0f8a31d..3ef152288732 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -435,6 +435,15 @@ def test_udf_with_array_type(self): self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) + def test_nondeterministic_udf(self): + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 123138117fdc..54b5a8656e1c 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -92,6 +92,7 @@ def __init__(self, func, func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) self.evalType = evalType + self._deterministic = True @property def returnType(self): @@ -129,7 +130,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.evalType) + self._name, wrapped_func, jdt, self.evalType, self._deterministic) return judf def __call__(self, *cols): @@ -161,5 +162,15 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType wrapper.evalType = self.evalType + wrapper.asNondeterministic = self.asNondeterministic return wrapper + + def asNondeterministic(self): + """ + Updates UserDefinedFunction to nondeterministic. + + .. 
versionadded:: 2.3 + """ + self._deterministic = False + return self diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 3ff476147b8b..dc2468a721e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.Try import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} @@ -41,8 +42,6 @@ import org.apache.spark.util.Utils * spark.udf * }}} * - * @note The user-defined functions must be deterministic. - * * @since 1.3.0 */ @InterfaceStability.Stable @@ -58,6 +57,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | pythonIncludes: ${udf.func.pythonIncludes} | pythonExec: ${udf.func.pythonExec} | dataType: ${udf.dataType} + | pythonEvalType: ${PythonEvalType.toString(udf.pythonEvalType)} + | udfDeterministic: ${udf.udfDeterministic} """.stripMargin) functionRegistry.createOrReplaceTempFunction(name, udf.builder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index ef27fbc2db7d..d3f743d9eb61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,9 +29,12 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - evalType: Int) + evalType: Int, + udfDeterministic: Boolean) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + override def toString: String = s"$name(${children.mkString(", ")})" override def nullable: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 348e49e473ed..50dca32cb786 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -29,10 +29,11 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - pythonEvalType: Int) { + pythonEvalType: Int, + udfDeterministic: Boolean) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, pythonEvalType) + PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 53d3f3456751..9e4a2e877695 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -109,4 +109,5 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - pythonEvalType = PythonEvalType.SQL_BATCHED_UDF) + pythonEvalType = PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) From 9348e684208465a8f75c893bdeaa30fc42c0cb5f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 26 Dec 2017 09:37:39 -0800 Subject: [PATCH 014/100] [SPARK-22833][EXAMPLE] Improvement SparkHive Scala Examples ## What changes were proposed in this pull request? Some improvements: 1. Point out we are using both Spark SQ native syntax and HQL syntax in the example 2. Avoid using the same table name with temp view, to not confuse users. 3. Create the external hive table with a directory that already has data, which is a more common use case. 4. Remove the usage of `spark.sql.parquet.writeLegacyFormat`. This config was introduced by https://github.com/apache/spark/pull/8566 and has nothing to do with Hive. 5. Remove `repartition` and `coalesce` example. These 2 are not Hive specific, we should put them in a different example file. BTW they can't accurately control the number of output files, `spark.sql.files.maxRecordsPerFile` also controls it. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20081 from cloud-fan/minor. --- .../examples/sql/hive/SparkHiveExample.scala | 75 +++++++++++-------- .../apache/spark/sql/internal/SQLConf.scala | 4 +- 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala index b193bd595127..70fb5b27366e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -102,40 +102,53 @@ object SparkHiveExample { // | 5| val_5| 5| val_5| // ... - // Create Hive managed table with Parquet - sql("CREATE TABLE records(key int, value string) STORED AS PARQUET") - // Save DataFrame to Hive managed table as Parquet format - val hiveTableDF = sql("SELECT * FROM records") - hiveTableDF.write.mode(SaveMode.Overwrite).saveAsTable("database_name.records") - // Create External Hive table with Parquet - sql("CREATE EXTERNAL TABLE records(key int, value string) " + - "STORED AS PARQUET LOCATION '/user/hive/warehouse/'") - // to make Hive Parquet format compatible with Spark Parquet format - spark.sqlContext.setConf("spark.sql.parquet.writeLegacyFormat", "true") - - // Multiple Parquet files could be created accordingly to volume of data under directory given. 
- val hiveExternalTableLocation = "/user/hive/warehouse/database_name.db/records" - - // Save DataFrame to Hive External table as compatible Parquet format - hiveTableDF.write.mode(SaveMode.Overwrite).parquet(hiveExternalTableLocation) - - // Turn on flag for Dynamic Partitioning - spark.sqlContext.setConf("hive.exec.dynamic.partition", "true") - spark.sqlContext.setConf("hive.exec.dynamic.partition.mode", "nonstrict") - - // You can create partitions in Hive table, so downstream queries run much faster. - hiveTableDF.write.mode(SaveMode.Overwrite).partitionBy("key") - .parquet(hiveExternalTableLocation) + // Create a Hive managed Parquet table, with HQL syntax instead of the Spark SQL native syntax + // `USING hive` + sql("CREATE TABLE hive_records(key int, value string) STORED AS PARQUET") + // Save DataFrame to the Hive managed table + val df = spark.table("src") + df.write.mode(SaveMode.Overwrite).saveAsTable("hive_records") + // After insertion, the Hive managed table has data now + sql("SELECT * FROM hive_records").show() + // +---+-------+ + // |key| value| + // +---+-------+ + // |238|val_238| + // | 86| val_86| + // |311|val_311| + // ... - // Reduce number of files for each partition by repartition - hiveTableDF.repartition($"key").write.mode(SaveMode.Overwrite) - .partitionBy("key").parquet(hiveExternalTableLocation) + // Prepare a Parquet data directory + val dataDir = "/tmp/parquet_data" + spark.range(10).write.parquet(dataDir) + // Create a Hive external Parquet table + sql(s"CREATE EXTERNAL TABLE hive_ints(key int) STORED AS PARQUET LOCATION '$dataDir'") + // The Hive external table should already have data + sql("SELECT * FROM hive_ints").show() + // +---+ + // |key| + // +---+ + // | 0| + // | 1| + // | 2| + // ... - // Control the number of files in each partition by coalesce - hiveTableDF.coalesce(10).write.mode(SaveMode.Overwrite) - .partitionBy("key").parquet(hiveExternalTableLocation) - // $example off:spark_hive$ + // Turn on flag for Hive Dynamic Partitioning + spark.sqlContext.setConf("hive.exec.dynamic.partition", "true") + spark.sqlContext.setConf("hive.exec.dynamic.partition.mode", "nonstrict") + // Create a Hive partitioned table using DataFrame API + df.write.partitionBy("key").format("hive").saveAsTable("hive_part_tbl") + // Partitioned column `key` will be moved to the end of the schema. + sql("SELECT * FROM hive_part_tbl").show() + // +-------+---+ + // | value|key| + // +-------+---+ + // |val_238|238| + // | val_86| 86| + // |val_311|311| + // ... 
spark.stop() + // $example off:spark_hive$ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 84fe4bb711a4..f16972e5427e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -336,8 +336,8 @@ object SQLConf { .createWithDefault(true) val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") - .doc("Whether to follow Parquet's format specification when converting Parquet schema to " + - "Spark SQL schema and vice versa.") + .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") .booleanConf .createWithDefault(false) From 91d1b300d467cc91948f71e87b7afe1b9027e60f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 26 Dec 2017 09:40:41 -0800 Subject: [PATCH 015/100] [SPARK-22894][SQL] DateTimeOperations should accept SQL like string type ## What changes were proposed in this pull request? `DateTimeOperations` accept [`StringType`](https://github.com/apache/spark/blob/ae998ec2b5548b7028d741da4813473dde1ad81e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala#L669), but: ``` spark-sql> SELECT '2017-12-24' + interval 2 months 2 seconds; Error in query: cannot resolve '(CAST('2017-12-24' AS DOUBLE) + interval 2 months 2 seconds)' due to data type mismatch: differing types in '(CAST('2017-12-24' AS DOUBLE) + interval 2 months 2 seconds)' (double and calendarinterval).; line 1 pos 7; 'Project [unresolvedalias((cast(2017-12-24 as double) + interval 2 months 2 seconds), None)] +- OneRowRelation spark-sql> ``` After this PR: ``` spark-sql> SELECT '2017-12-24' + interval 2 months 2 seconds; 2018-02-24 00:00:02 Time taken: 0.2 seconds, Fetched 1 row(s) ``` ## How was this patch tested? unit tests Author: Yuming Wang Closes #20067 from wangyum/SPARK-22894. --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 6 ++++-- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 13 ++++++++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2f306f58b7b8..1c4be547739d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -324,9 +324,11 @@ object TypeCoercion { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), right) => + case a @ BinaryArithmetic(left @ StringType(), right) + if right.dataType != CalendarIntervalType => a.makeCopy(Array(Cast(left, DoubleType), right)) - case a @ BinaryArithmetic(left, right @ StringType()) => + case a @ BinaryArithmetic(left, right @ StringType()) + if left.dataType != CalendarIntervalType => a.makeCopy(Array(left, Cast(right, DoubleType))) // For equality between string and timestamp we cast the string to a timestamp diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e077285ade5..1972dec7b86c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext import java.net.{MalformedURLException, URL} -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.{AccumulatorSuite, SparkException} @@ -2760,6 +2760,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-22894: DateTimeOperations should accept SQL like string type") { + val date = "2017-12-24" + val str = sql(s"SELECT CAST('$date' as STRING) + interval 2 months 2 seconds") + val dt = sql(s"SELECT CAST('$date' as DATE) + interval 2 months 2 seconds") + val ts = sql(s"SELECT CAST('$date' as TIMESTAMP) + interval 2 months 2 seconds") + + checkAnswer(str, Row("2018-02-24 00:00:02") :: Nil) + checkAnswer(dt, Row(Date.valueOf("2018-02-24")) :: Nil) + checkAnswer(ts, Row(Timestamp.valueOf("2018-02-24 00:00:02")) :: Nil) + } + // Only New OrcFileFormat supports this Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, "parquet").foreach { format => From 6674acd1edc8c657049113df95af3378b5db8806 Mon Sep 17 00:00:00 2001 From: "xu.wenchun" Date: Wed, 27 Dec 2017 10:08:32 +0800 Subject: [PATCH 016/100] [SPARK-22846][SQL] Fix table owner is null when creating table through spark sql or thriftserver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? fix table owner is null when create new table through spark sql ## How was this patch tested? manual test. 1、first create a table 2、then select the table properties from mysql which connected to hive metastore Please review http://spark.apache.org/contributing.html before opening a pull request. Author: xu.wenchun Closes #20034 from BruceXu1991/SPARK-22846. --- .../scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 7233944dc96d..7b7f4e0f1021 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -186,7 +186,7 @@ private[hive] class HiveClientImpl( /** Returns the configuration for the current session. 
*/ def conf: HiveConf = state.getConf - private val userName = state.getAuthenticator.getUserName + private val userName = conf.getUser override def getConf(key: String, defaultValue: String): String = { conf.get(key, defaultValue) From b8bfce51abf28c66ba1fc67b0f25fe1617c81025 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Dec 2017 20:51:26 +0900 Subject: [PATCH 017/100] [SPARK-22324][SQL][PYTHON][FOLLOW-UP] Update setup.py file. ## What changes were proposed in this pull request? This is a follow-up pr of #19884 updating setup.py file to add pyarrow dependency. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #20089 from ueshin/issues/SPARK-22324/fup1. --- python/README.md | 2 +- python/setup.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/README.md b/python/README.md index 84ec88141cb0..3f17fdb98a08 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.6), but additional sub-packages have their own requirements (including numpy and pandas). +At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). diff --git a/python/setup.py b/python/setup.py index 310670e697a8..251d4526d4dd 100644 --- a/python/setup.py +++ b/python/setup.py @@ -201,7 +201,7 @@ def _supports_symlinks(): extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas>=0.19.2'] + 'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0'] }, classifiers=[ 'Development Status :: 5 - Production/Stable', @@ -210,6 +210,7 @@ def _supports_symlinks(): 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy'] ) From 774715d5c73ab6d410208fa1675cd11166e03165 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 27 Dec 2017 23:53:10 +0800 Subject: [PATCH 018/100] [SPARK-22904][SQL] Add tests for decimal operations and string casts ## What changes were proposed in this pull request? Test coverage for arithmetic operations leading to: 1. Precision loss 2. Overflow Moreover, tests for casting bad string to other input types and for using bad string as operators of some functions. ## How was this patch tested? added tests Author: Marco Gaido Closes #20084 from mgaido91/SPARK-22904. 
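For a quick check outside the test harness, the same behaviors can be reproduced from spark-shell. This is only an illustrative sketch (the `spark` session and the column aliases are assumed here; they are not part of the patch):

```
// Decimal arithmetic that overflows returns NULL rather than failing
spark.sql("SELECT (5e36 + 0.1) + 5e36 AS overflowed").show()
// +----------+
// |overflowed|
// +----------+
// |      null|
// +----------+

// Casting a non-numeric string to a numeric type also returns NULL
spark.sql("SELECT CAST('aa' AS INT) AS bad_cast").show()
// +--------+
// |bad_cast|
// +--------+
// |    null|
// +--------+
```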
--- .../native/decimalArithmeticOperations.sql | 33 +++ .../native/stringCastAndExpressions.sql | 57 ++++ .../decimalArithmeticOperations.sql.out | 82 ++++++ .../native/stringCastAndExpressions.sql.out | 261 ++++++++++++++++++ 4 files changed, 433 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql new file mode 100644 index 000000000000..c8e108ac2c45 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -0,0 +1,33 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TEMPORARY VIEW t AS SELECT 1.0 as a, 0.0 as b; + +-- division, remainder and pmod by 0 return NULL +select a / b from t; +select a % b from t; +select pmod(a, b) from t; + +-- arithmetic operations causing an overflow return NULL +select (5e36 + 0.1) + 5e36; +select (-4e36 - 0.1) - 7e36; +select 12345678901234567890.0 * 12345678901234567890.0; +select 1e35 / 0.1; + +-- arithmetic operations causing a precision loss return NULL +select 123456789123456789.1234567890 * 1.123456789123456789; +select 0.001 / 9876543210987654321098765432109876543.2 diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql new file mode 100644 index 000000000000..f17adb56dee9 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql @@ -0,0 +1,57 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TEMPORARY VIEW t AS SELECT 'aa' as a; + +-- casting to data types which are unable to represent the string input returns NULL +select cast(a as byte) from t; +select cast(a as short) from t; +select cast(a as int) from t; +select cast(a as long) from t; +select cast(a as float) from t; +select cast(a as double) from t; +select cast(a as decimal) from t; +select cast(a as boolean) from t; +select cast(a as timestamp) from t; +select cast(a as date) from t; +-- casting to binary works correctly +select cast(a as binary) from t; +-- casting to array, struct or map throws exception +select cast(a as array) from t; +select cast(a as struct) from t; +select cast(a as map) from t; + +-- all timestamp/date expressions return NULL if bad input strings are provided +select to_timestamp(a) from t; +select to_timestamp('2018-01-01', a) from t; +select to_unix_timestamp(a) from t; +select to_unix_timestamp('2018-01-01', a) from t; +select unix_timestamp(a) from t; +select unix_timestamp('2018-01-01', a) from t; +select from_unixtime(a) from t; +select from_unixtime('2018-01-01', a) from t; +select next_day(a, 'MO') from t; +select next_day('2018-01-01', a) from t; +select trunc(a, 'MM') from t; +select trunc('2018-01-01', a) from t; + +-- some functions return NULL if bad input is provided +select unhex('-123'); +select sha2(a, a) from t; +select get_json_object(a, a) from t; +select json_tuple(a, a) from t; +select from_json(a, 'a INT') from t; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out new file mode 100644 index 000000000000..ce02f6adc456 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -0,0 +1,82 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1.0 as a, 0.0 as b +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select a / b from t +-- !query 1 schema +struct<(CAST(a AS DECIMAL(2,1)) / CAST(b AS DECIMAL(2,1))):decimal(8,6)> +-- !query 1 output +NULL + + +-- !query 2 +select a % b from t +-- !query 2 schema +struct<(CAST(a AS DECIMAL(2,1)) % CAST(b AS DECIMAL(2,1))):decimal(1,1)> +-- !query 2 output +NULL + + +-- !query 3 +select pmod(a, b) from t +-- !query 3 schema +struct +-- !query 3 output +NULL + + +-- !query 4 +select (5e36 + 0.1) + 5e36 +-- !query 4 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 4 output +NULL + + +-- !query 5 +select (-4e36 - 0.1) - 7e36 +-- !query 5 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 5 output +NULL + + +-- !query 6 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 6 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 
6 output +NULL + + +-- !query 7 +select 1e35 / 0.1 +-- !query 7 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +-- !query 7 output +NULL + + +-- !query 8 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 8 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +-- !query 8 output +NULL + + +-- !query 9 +select 0.001 / 9876543210987654321098765432109876543.2 +-- !query 9 schema +struct<(CAST(0.001 AS DECIMAL(38,3)) / CAST(9876543210987654321098765432109876543.2 AS DECIMAL(38,3))):decimal(38,37)> +-- !query 9 output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out new file mode 100644 index 000000000000..8ed282024441 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out @@ -0,0 +1,261 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 32 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 'aa' as a +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select cast(a as byte) from t +-- !query 1 schema +struct +-- !query 1 output +NULL + + +-- !query 2 +select cast(a as short) from t +-- !query 2 schema +struct +-- !query 2 output +NULL + + +-- !query 3 +select cast(a as int) from t +-- !query 3 schema +struct +-- !query 3 output +NULL + + +-- !query 4 +select cast(a as long) from t +-- !query 4 schema +struct +-- !query 4 output +NULL + + +-- !query 5 +select cast(a as float) from t +-- !query 5 schema +struct +-- !query 5 output +NULL + + +-- !query 6 +select cast(a as double) from t +-- !query 6 schema +struct +-- !query 6 output +NULL + + +-- !query 7 +select cast(a as decimal) from t +-- !query 7 schema +struct +-- !query 7 output +NULL + + +-- !query 8 +select cast(a as boolean) from t +-- !query 8 schema +struct +-- !query 8 output +NULL + + +-- !query 9 +select cast(a as timestamp) from t +-- !query 9 schema +struct +-- !query 9 output +NULL + + +-- !query 10 +select cast(a as date) from t +-- !query 10 schema +struct +-- !query 10 output +NULL + + +-- !query 11 +select cast(a as binary) from t +-- !query 11 schema +struct +-- !query 11 output +aa + + +-- !query 12 +select cast(a as array) from t +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve 't.`a`' due to data type mismatch: cannot cast string to array; line 1 pos 7 + + +-- !query 13 +select cast(a as struct) from t +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +cannot resolve 't.`a`' due to data type mismatch: cannot cast string to struct; line 1 pos 7 + + +-- !query 14 +select cast(a as map) from t +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +cannot resolve 't.`a`' due to data type mismatch: cannot cast string to map; line 1 pos 7 + + +-- !query 15 +select to_timestamp(a) from t +-- !query 15 schema +struct +-- !query 15 output +NULL + + +-- !query 16 +select to_timestamp('2018-01-01', a) from t +-- !query 16 schema +struct +-- !query 16 output +NULL + + +-- !query 17 +select to_unix_timestamp(a) from t +-- !query 17 schema +struct +-- !query 17 output +NULL + + +-- !query 18 +select to_unix_timestamp('2018-01-01', a) from t +-- !query 18 schema +struct +-- !query 18 
output +NULL + + +-- !query 19 +select unix_timestamp(a) from t +-- !query 19 schema +struct +-- !query 19 output +NULL + + +-- !query 20 +select unix_timestamp('2018-01-01', a) from t +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select from_unixtime(a) from t +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select from_unixtime('2018-01-01', a) from t +-- !query 22 schema +struct +-- !query 22 output +NULL + + +-- !query 23 +select next_day(a, 'MO') from t +-- !query 23 schema +struct +-- !query 23 output +NULL + + +-- !query 24 +select next_day('2018-01-01', a) from t +-- !query 24 schema +struct +-- !query 24 output +NULL + + +-- !query 25 +select trunc(a, 'MM') from t +-- !query 25 schema +struct +-- !query 25 output +NULL + + +-- !query 26 +select trunc('2018-01-01', a) from t +-- !query 26 schema +struct +-- !query 26 output +NULL + + +-- !query 27 +select unhex('-123') +-- !query 27 schema +struct +-- !query 27 output +NULL + + +-- !query 28 +select sha2(a, a) from t +-- !query 28 schema +struct +-- !query 28 output +NULL + + +-- !query 29 +select get_json_object(a, a) from t +-- !query 29 schema +struct +-- !query 29 output +NULL + + +-- !query 30 +select json_tuple(a, a) from t +-- !query 30 schema +struct +-- !query 30 output +NULL + + +-- !query 31 +select from_json(a, 'a INT') from t +-- !query 31 schema +struct> +-- !query 31 output +NULL From 753793bc84df805e519cf59f6804ab26bd02d76e Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 27 Dec 2017 17:31:12 -0800 Subject: [PATCH 019/100] [SPARK-22899][ML][STREAMING] Fix OneVsRestModel transform on streaming data failed. ## What changes were proposed in this pull request? Fix OneVsRestModel transform on streaming data failed. ## How was this patch tested? UT will be added soon, once #19979 merged. (Need a helper test method there) Author: WeichenXu Closes #20077 from WeichenXu123/fix_ovs_model_transform. --- .../scala/org/apache/spark/ml/classification/OneVsRest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 3ab99b35ece2..f04fde2cbbca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -165,7 +165,7 @@ final class OneVsRestModel private[ml] ( val newDataset = dataset.withColumn(accColName, initUDF()) // persist if underlying dataset is not persistent. - val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val handlePersistence = !dataset.isStreaming && dataset.storageLevel == StorageLevel.NONE if (handlePersistence) { newDataset.persist(StorageLevel.MEMORY_AND_DISK) } From 5683984520cfe9e9acf49e47a84a56af155a8ad2 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 28 Dec 2017 12:28:19 +0800 Subject: [PATCH 020/100] [SPARK-18016][SQL][FOLLOW-UP] Code Generation: Constant Pool Limit - reduce entries for mutable state ## What changes were proposed in this pull request? This PR addresses additional review comments in #19811 ## How was this patch tested? Existing test suites Author: Kazuaki Ishizaki Closes #20036 from kiszk/SPARK-18066-followup. 
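For context, the mechanism this follow-up keeps polishing is how generated-code state is registered through `CodegenContext.addMutableState`: by default a state is compacted into an array slot (each array bounded by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`), so it does not consume an extra constant pool entry, while `forceInline = true` is reserved for operators that occur at most a few times per task. A rough sketch of the two call shapes, taken from the hunks below (argument lists abbreviated):

```
// Default: compacted into an array slot, no extra constant pool entry per state
val pattern = ctx.addMutableState(patternClass, "patternLike",
  v => s"""$v = $patternClass.compile("$regexStr");""")

// Inlined as a plain field, used only for states that are rare within a task
val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true)
```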
--- .../expressions/codegen/CodeGenerator.scala | 4 +- .../expressions/regexpExpressions.scala | 62 +++++++++---------- .../apache/spark/sql/execution/SortExec.scala | 2 +- .../sql/execution/WholeStageCodegenExec.scala | 2 +- .../aggregate/HashAggregateExec.scala | 16 +++-- .../execution/basicPhysicalOperators.scala | 4 +- .../joins/BroadcastHashJoinExec.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 7 ++- 8 files changed, 51 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d6eccadcfb63..2c714c228e6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -190,7 +190,7 @@ class CodegenContext { /** * Returns the reference of next available slot in current compacted array. The size of each - * compacted array is controlled by the config `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. * Once reaching the threshold, new compacted array is created. */ def getNextSlot(): String = { @@ -352,7 +352,7 @@ class CodegenContext { def initMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - val initCodes = mutableStateInitCode.distinct + val initCodes = mutableStateInitCode.distinct.map(_ + "\n") // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index fa5425c77ebb..f3e8f6de5897 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -118,9 +118,8 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - // inline mutable state since not many Like operations in a task val pattern = ctx.addMutableState(patternClass, "patternLike", - v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) + v => s"""$v = $patternClass.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
val eval = left.genCode(ctx) @@ -143,9 +142,9 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi val rightStr = ctx.freshName("rightStr") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - String $rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile($escapeFunc($rightStr)); - ${ev.value} = $pattern.matcher(${eval1}.toString()).matches(); + String $rightStr = $eval2.toString(); + $patternClass $pattern = $patternClass.compile($escapeFunc($rightStr)); + ${ev.value} = $pattern.matcher($eval1.toString()).matches(); """ }) } @@ -194,9 +193,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - // inline mutable state since not many RLike operations in a task val pattern = ctx.addMutableState(patternClass, "patternRLike", - v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) + v => s"""$v = $patternClass.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) @@ -219,9 +217,9 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress val pattern = ctx.freshName("pattern") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - String $rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile($rightStr); - ${ev.value} = $pattern.matcher(${eval1}.toString()).find(0); + String $rightStr = $eval2.toString(); + $patternClass $pattern = $patternClass.compile($rightStr); + ${ev.value} = $pattern.matcher($eval1.toString()).find(0); """ }) } @@ -338,25 +336,25 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { s""" - if (!$regexp.equals(${termLastRegex})) { + if (!$regexp.equals($termLastRegex)) { // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + $termLastRegex = $regexp.clone(); + $termPattern = $classNamePattern.compile($termLastRegex.toString()); } - if (!$rep.equals(${termLastReplacementInUTF8})) { + if (!$rep.equals($termLastReplacementInUTF8)) { // replacement string changed - ${termLastReplacementInUTF8} = $rep.clone(); - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + $termLastReplacementInUTF8 = $rep.clone(); + $termLastReplacement = $termLastReplacementInUTF8.toString(); } - $classNameStringBuffer ${termResult} = new $classNameStringBuffer(); - java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString()); + $classNameStringBuffer $termResult = new $classNameStringBuffer(); + java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString()); - while (${matcher}.find()) { - ${matcher}.appendReplacement(${termResult}, ${termLastReplacement}); + while ($matcher.find()) { + $matcher.appendReplacement($termResult, $termLastReplacement); } - ${matcher}.appendTail(${termResult}); - ${ev.value} = UTF8String.fromString(${termResult}.toString()); - ${termResult} = null; + $matcher.appendTail($termResult); + ${ev.value} = UTF8String.fromString($termResult.toString()); + $termResult = null; $setEvNotNull """ }) @@ -425,19 +423,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" - if (!$regexp.equals(${termLastRegex})) { + if 
(!$regexp.equals($termLastRegex)) { // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + $termLastRegex = $regexp.clone(); + $termPattern = $classNamePattern.compile($termLastRegex.toString()); } - java.util.regex.Matcher ${matcher} = - ${termPattern}.matcher($subject.toString()); - if (${matcher}.find()) { - java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult(); - if (${matchResult}.group($idx) == null) { + java.util.regex.Matcher $matcher = + $termPattern.matcher($subject.toString()); + if ($matcher.find()) { + java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); + if ($matchResult.group($idx) == null) { ${ev.value} = UTF8String.EMPTY_UTF8; } else { - ${ev.value} = UTF8String.fromString(${matchResult}.group($idx)); + ${ev.value} = UTF8String.fromString($matchResult.group($idx)); } $setEvNotNull } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index daff3c49e751..ef1bb1c2a446 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -138,7 +138,7 @@ case class SortExec( // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. val thisPlan = ctx.addReferenceObj("plan", this) - // inline mutable state since not many Sort operations in a task + // Inline mutable state since not many Sort operations in a task sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter", v => s"$v = $thisPlan.createSorter();", forceInline = true) val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 9e7008d1e0c3..065954559e48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -283,7 +283,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp override def doProduce(ctx: CodegenContext): String = { // Right now, InputAdapter is only used when there is one input RDD. 
- // inline mutable state since an inputAdaptor in a task + // Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", forceInline = true) val row = ctx.freshName("row") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b1af360d8509..9a6f1c6dfa6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -587,31 +587,35 @@ case class HashAggregateExec( fastHashMapClassName, groupingKeySchema, bufferSchema).generate() ctx.addInnerClass(generatedMap) + // Inline mutable state since not many aggregation operations in a task fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "vectorizedHastHashMap", - v => s"$v = new $fastHashMapClassName();") - ctx.addMutableState(s"java.util.Iterator", "vectorizedFastHashMapIter") + v => s"$v = new $fastHashMapClassName();", forceInline = true) + ctx.addMutableState(s"java.util.Iterator", "vectorizedFastHashMapIter", + forceInline = true) } else { val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, fastHashMapClassName, groupingKeySchema, bufferSchema).generate() ctx.addInnerClass(generatedMap) + // Inline mutable state since not many aggregation operations in a task fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "fastHashMap", v => s"$v = new $fastHashMapClassName(" + - s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());") + s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());", + forceInline = true) ctx.addMutableState( "org.apache.spark.unsafe.KVIterator", - "fastHashMapIter") + "fastHashMapIter", forceInline = true) } } // Create a name for the iterator from the regular hash map. 
- // inline mutable state since not many aggregation operations in a task + // Inline mutable state since not many aggregation operations in a task val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, "mapIter", forceInline = true) // create hashMap val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap", - v => s"$v = $thisPlan.createHashMap();") + v => s"$v = $thisPlan.createHashMap();", forceInline = true) sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter", forceInline = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 78137d3f97cf..a15a8d11aa2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -284,7 +284,7 @@ case class SampleExec( val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") - // inline mutable state since not many Sample operations in a task + // Inline mutable state since not many Sample operations in a task val sampler = ctx.addMutableState(s"$samplerClass", "sampleReplace", v => { val initSamplerFuncName = ctx.addNewFunction(initSampler, @@ -371,7 +371,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName - // inline mutable state since not many Range operations in a task + // Inline mutable state since not many Range operations in a task val taskContext = ctx.addMutableState("TaskContext", "taskContext", v => s"$v = TaskContext.get();", forceInline = true) val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index ee763e23415c..1918fcc5482d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -139,7 +139,7 @@ case class BroadcastHashJoinExec( // At the end of the task, we update the avg hash probe. val avgHashProbe = metricTerm(ctx, "avgHashProbe") - // inline mutable state since not many join operations in a task + // Inline mutable state since not many join operations in a task val relationTerm = ctx.addMutableState(clsName, "relation", v => s""" | $v = (($clsName) $broadcast.value()).asReadOnlyCopy(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 073730462a75..94405410cce9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -422,7 +422,7 @@ case class SortMergeJoinExec( */ private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. 
- // inline mutable state since not many join operations in a task + // Inline mutable state since not many join operations in a task val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) @@ -440,8 +440,9 @@ case class SortMergeJoinExec( val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold + // Inline mutable state since not many join operations in a task val matches = ctx.addMutableState(clsName, "matches", - v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);") + v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) // Copy the left keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -576,7 +577,7 @@ case class SortMergeJoinExec( override def needCopyResult: Boolean = true override def doProduce(ctx: CodegenContext): String = { - // inline mutable state since not many join operations in a task + // Inline mutable state since not many join operations in a task val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", v => s"$v = inputs[0];", forceInline = true) val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", From 32ec269d08313720aae3b47cce2f5e9c19702811 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 28 Dec 2017 12:35:17 +0800 Subject: [PATCH 021/100] [SPARK-22909][SS] Move Structured Streaming v2 APIs to streaming folder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR moves Structured Streaming v2 APIs to streaming folder as following: ``` sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming ├── ContinuousReadSupport.java ├── ContinuousWriteSupport.java ├── MicroBatchReadSupport.java ├── MicroBatchWriteSupport.java ├── reader │   ├── ContinuousDataReader.java │   ├── ContinuousReader.java │   ├── MicroBatchReader.java │   ├── Offset.java │   └── PartitionOffset.java └── writer └── ContinuousWriter.java ``` ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #20093 from zsxwing/move. 
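For downstream code the change is mostly a matter of import paths; an illustrative before/after for one of the moved interfaces (the caller shown here is hypothetical):

```
// Before this patch:
// import org.apache.spark.sql.sources.v2.reader.ContinuousReader
// After this patch:
import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader
```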
--- .../sources/v2/{ => streaming}/ContinuousReadSupport.java | 7 ++++--- .../sources/v2/{ => streaming}/ContinuousWriteSupport.java | 6 ++++-- .../sources/v2/{ => streaming}/MicroBatchReadSupport.java | 6 ++++-- .../sources/v2/{ => streaming}/MicroBatchWriteSupport.java | 4 +++- .../v2/{ => streaming}/reader/ContinuousDataReader.java | 6 ++---- .../v2/{ => streaming}/reader/ContinuousReader.java | 4 ++-- .../v2/{ => streaming}/reader/MicroBatchReader.java | 4 ++-- .../sql/sources/v2/{ => streaming}/reader/Offset.java | 2 +- .../sources/v2/{ => streaming}/reader/PartitionOffset.java | 2 +- .../v2/{ => streaming}/writer/ContinuousWriter.java | 5 ++++- .../execution/datasources/v2/DataSourceV2ScanExec.scala | 1 + .../sql/execution/datasources/v2/WriteToDataSourceV2.scala | 1 + .../sql/execution/streaming/MicroBatchExecution.scala | 2 +- .../spark/sql/execution/streaming/RateSourceProvider.scala | 3 ++- .../spark/sql/execution/streaming/RateStreamOffset.scala | 2 +- .../spark/sql/execution/streaming/StreamingRelation.scala | 3 ++- .../streaming/continuous/ContinuousDataSourceRDDIter.scala | 1 + .../streaming/continuous/ContinuousExecution.scala | 7 ++++--- .../streaming/continuous/ContinuousRateStreamSource.scala | 3 ++- .../execution/streaming/continuous/EpochCoordinator.scala | 5 +++-- .../execution/streaming/sources/RateStreamSourceV2.scala | 1 + .../spark/sql/execution/streaming/sources/memoryV2.scala | 4 +++- .../org/apache/spark/sql/streaming/DataStreamReader.scala | 3 ++- .../apache/spark/sql/streaming/StreamingQueryManager.scala | 2 +- .../spark/sql/execution/streaming/OffsetSeqLogSuite.scala | 1 - .../spark/sql/execution/streaming/RateSourceSuite.scala | 1 - .../spark/sql/execution/streaming/RateSourceV2Suite.scala | 3 ++- 27 files changed, 54 insertions(+), 35 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/ContinuousReadSupport.java (88%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/ContinuousWriteSupport.java (90%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/MicroBatchReadSupport.java (90%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/MicroBatchWriteSupport.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/reader/ContinuousDataReader.java (90%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/reader/ContinuousReader.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/reader/MicroBatchReader.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/reader/Offset.java (97%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/reader/PartitionOffset.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/writer/ContinuousWriter.java (87%) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java similarity index 88% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index ae4f85820649..8837bae6156b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java +++ 
b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -15,12 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; -import org.apache.spark.sql.sources.v2.reader.ContinuousReader; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java index 362d5f52b4d0..ec15e436672b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java @@ -15,13 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.ContinuousWriter; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter; import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java index 442cad029d21..3c87a3db6824 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java @@ -15,12 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.MicroBatchReader; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java index 63640779b955..b5e3e44cd633 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java index 11b99a93f149..ca9a290e97a0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java @@ -15,11 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; -import org.apache.spark.sql.sources.v2.reader.PartitionOffset; - -import java.io.IOException; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A variation on {@link DataReader} for use with streaming in continuous processing mode. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index 34141d6cd85f..f0b205869ed6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -15,10 +15,10 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; -import org.apache.spark.sql.sources.v2.reader.PartitionOffset; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import java.util.Optional; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index bd15c07d87f6..70ff75680603 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; -import org.apache.spark.sql.sources.v2.reader.Offset; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import java.util.Optional; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java similarity index 97% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java index ce1c48974205..517fdab16684 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; /** * An abstract representation of progress through a [[MicroBatchReader]] or [[ContinuousReader]]. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java index 07826b668847..729a6123034f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; import java.io.Serializable; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java similarity index 87% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java index 618f47ed79ca..723395bd1e96 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java @@ -15,9 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.writer; +package org.apache.spark.sql.sources.v2.streaming.writer; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataWriter; +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** * A {@link DataSourceV2Writer} for use with continuous stream processing. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index e4fca1b10dfa..49c506bc560c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 1862da8892cb..f0bdf84bb7a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 20f9810faa5c..9a7a13fcc580 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.v2.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 3f85fa913f28..d02cf882b61a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousRateStreamR import org.apache.spark.sql.execution.streaming.sources.RateStreamV2Reader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala index 65d6d1893616..261d69bbd984 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -23,7 +23,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.sources.v2 case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, ValueRunTimeMsPair]) - extends v2.reader.Offset { + extends v2.streaming.reader.Offset { implicit val defaultFormats: DefaultFormats = DefaultFormats override val json = Serialization.write(partitionToValueAndRunTimeMs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 0ca2e7854d94..a9d50e3a112e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 89fb2ace2091..d79e4bd65f56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, Ro import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, PartitionOffset} import org.apache.spark.sql.streaming.ProcessingTime import org.apache.spark.util.{SystemClock, ThreadUtils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 1c35b06bd4b8..2843ab13bde2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,9 +29,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, ContinuousWriteSupport, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 89a8562b4b59..c9aa78a5a2e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -27,8 +27,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} import 
org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} case class ContinuousRateStreamPartitionOffset( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 7f1e8abd79b9..98017c3ac6a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -26,8 +26,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper -import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.{ContinuousWriter, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 1c66aed8690a..97bada08bcd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.sources.v2.DataSourceV2Options import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} import org.apache.spark.util.SystemClock diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 972248d5e4df..da7c31cf6242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -29,7 +29,9 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.sources.v2.{ContinuousWriteSupport, DataSourceV2, DataSourceV2Options, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import 
org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f17935e86f45..acd5ca17dcf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Options, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index e808ffaa9641..b508f4406138 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index 4868ba4e6893..e6cdc063c4e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.test.SharedSQLContext class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala index ceba27b26e57..03d0f63fa4d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -21,7 +21,6 @@ import java.util.concurrent.TimeUnit import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import 
org.apache.spark.util.ManualClock
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
index dc833b2ccaa2..e11705a227f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming.continuous._
 import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamSourceV2, RateStreamV2Reader}
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Options, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.DataSourceV2Options
+import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport
 import org.apache.spark.sql.streaming.StreamTest

 class RateSourceV2Suite extends StreamTest {

From 171f6ddadc6185ffcc6ad82e5f48952fb49095b2 Mon Sep 17 00:00:00 2001
From: Yinan Li
Date: Thu, 28 Dec 2017 13:44:44 +0900
Subject: [PATCH 022/100] [SPARK-22757][KUBERNETES] Enable use of remote dependencies (http, s3, gcs, etc.) in Kubernetes mode

## What changes were proposed in this pull request?

This PR expands the Kubernetes mode to be able to use remote dependencies on http/https endpoints, GCS, S3, etc. It adds steps for configuring and appending the Kubernetes init-container to the driver and executor pods for downloading remote dependencies. [Init-containers](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/), as the name suggests, are containers that are run to completion before the main containers start, and are often used to perform initialization tasks prior to starting the main containers. We use init-containers to localize remote application dependencies before the driver/executors start running. The code that the init-container runs is also included.

This PR also adds a step for mounting user-specified secrets, which may store credentials for accessing data storage such as S3 and Google Cloud Storage (GCS), into the driver and executors.

## How was this patch tested?

* The patch contains unit tests which are passing.
* Manual testing: `./build/mvn -Pkubernetes clean package` succeeded.
* Manual testing of the following cases:
  * [x] Running SparkPi using container-local spark-example jar.
  * [x] Running SparkPi using container-local spark-example jar with user-specified secret mounted.
  * [x] Running SparkPi using spark-example jar hosted remotely on an https endpoint.

cc rxin felixcheung mateiz (shepherd) k8s-big-data SIG members & contributors: mccheah foxish ash211 ssuchter varunkatta kimoonkim erikerlandson tnachen ifilonenko liyinan926 reviewers: vanzin felixcheung jiangxb1987 mridulm

Author: Yinan Li

Closes #19954 from liyinan926/init-container.
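As a rough sketch of the user-facing configuration surface this patch adds (the property names come from the Config.scala and Constants.scala changes below; the image tag, jar URL, secret name, and mount path are hypothetical placeholders rather than values taken from this patch):

```scala
import org.apache.spark.SparkConf

object KubernetesRemoteDepsExample {
  // Submission-side configuration that would exercise both new code paths: a remote jar
  // triggers the dependency-downloading init-container, and a named Kubernetes secret is
  // mounted into the driver and executor pods.
  val conf = new SparkConf()
    // Image for the init-container that localizes remote dependencies (placeholder tag).
    .set("spark.kubernetes.initContainer.image", "my-registry/spark-init:latest")
    // A dependency that is not container-local, so the init-container downloads it into
    // the jars download directory below (the default location).
    .set("spark.jars", "https://example.com/jars/my-app.jar")
    .set("spark.kubernetes.mountDependencies.jarsDownloadDir", "/var/spark-data/spark-jars")
    .set("spark.kubernetes.mountDependencies.timeout", "300s")
    .set("spark.kubernetes.mountDependencies.maxSimultaneousDownloads", "5")
    // Mount the Kubernetes secret named "aws-creds" at /etc/secrets in both the driver
    // and the executors.
    .set("spark.kubernetes.driver.secrets.aws-creds", "/etc/secrets")
    .set("spark.kubernetes.executor.secrets.aws-creds", "/etc/secrets")
}
```

With settings like these, the new DriverConfigOrchestrator should include both the init-container bootstrap step and the secret-mounting step when building the driver spec.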
--- .../org/apache/spark/deploy/k8s/Config.scala | 67 +++++++- .../apache/spark/deploy/k8s/Constants.scala | 11 ++ .../deploy/k8s/InitContainerBootstrap.scala | 119 +++++++++++++ ...sFileUtils.scala => KubernetesUtils.scala} | 50 +++++- .../deploy/k8s/MountSecretsBootstrap.scala | 62 +++++++ ...ala => PodWithDetachedInitContainer.scala} | 34 ++-- .../k8s/SparkKubernetesClientFactory.scala | 2 +- ...r.scala => DriverConfigOrchestrator.scala} | 91 +++++++--- .../submit/KubernetesClientApplication.scala | 26 ++- ...ala => BasicDriverConfigurationStep.scala} | 63 ++++--- .../steps/DependencyResolutionStep.scala | 27 +-- .../steps/DriverConfigurationStep.scala | 2 +- .../DriverInitContainerBootstrapStep.scala | 95 +++++++++++ .../submit/steps/DriverMountSecretsStep.scala | 38 +++++ .../steps/DriverServiceBootstrapStep.scala | 17 +- .../BasicInitContainerConfigurationStep.scala | 67 ++++++++ .../InitContainerConfigOrchestrator.scala | 79 +++++++++ .../InitContainerConfigurationStep.scala | 25 +++ .../InitContainerMountSecretsStep.scala | 39 +++++ .../initcontainer/InitContainerSpec.scala | 37 ++++ .../rest/k8s/SparkPodInitContainer.scala | 116 +++++++++++++ .../cluster/k8s/ExecutorPodFactory.scala | 77 ++++++--- .../k8s/KubernetesClusterManager.scala | 62 ++++++- ...la => DriverConfigOrchestratorSuite.scala} | 69 ++++++-- .../deploy/k8s/submit/SecretVolumeUtils.scala | 36 ++++ ...> BasicDriverConfigurationStepSuite.scala} | 4 +- ...riverInitContainerBootstrapStepSuite.scala | 160 ++++++++++++++++++ .../steps/DriverMountSecretsStepSuite.scala | 49 ++++++ ...cInitContainerConfigurationStepSuite.scala | 95 +++++++++++ ...InitContainerConfigOrchestratorSuite.scala | 80 +++++++++ .../InitContainerMountSecretsStepSuite.scala | 57 +++++++ .../rest/k8s/SparkPodInitContainerSuite.scala | 86 ++++++++++ .../cluster/k8s/ExecutorPodFactorySuite.scala | 82 ++++++++- .../src/main/dockerfiles/driver/Dockerfile | 3 +- .../src/main/dockerfiles/executor/Dockerfile | 3 +- .../dockerfiles/init-container/Dockerfile | 24 +++ 36 files changed, 1783 insertions(+), 171 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/{submit/KubernetesFileUtils.scala => KubernetesUtils.scala} (57%) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/{ConfigurationUtils.scala => PodWithDetachedInitContainer.scala} (53%) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/{DriverConfigurationStepsOrchestrator.scala => DriverConfigOrchestrator.scala} (53%) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/{BaseDriverConfigurationStep.scala => BasicDriverConfigurationStep.scala} (70%) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala create mode 100644 
resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/{DriverConfigurationStepsOrchestratorSuite.scala => DriverConfigOrchestratorSuite.scala} (51%) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/{BaseDriverConfigurationStepSuite.scala => BasicDriverConfigurationStepSuite.scala} (97%) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala create mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 45f527959cbe..e5d79d9a9d9d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -20,7 +20,6 @@ import java.util.concurrent.TimeUnit import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigBuilder -import org.apache.spark.network.util.ByteUnit private[spark] object Config extends Logging { @@ -132,20 +131,72 @@ private[spark] object Config extends Logging { val JARS_DOWNLOAD_LOCATION = ConfigBuilder("spark.kubernetes.mountDependencies.jarsDownloadDir") - .doc("Location to download jars to in the driver and executors. When using" + - " spark-submit, this directory must be empty and will be mounted as an empty directory" + - " volume on the driver and executor pod.") + .doc("Location to download jars to in the driver and executors. 
When using " + + "spark-submit, this directory must be empty and will be mounted as an empty directory " + + "volume on the driver and executor pod.") .stringConf .createWithDefault("/var/spark-data/spark-jars") val FILES_DOWNLOAD_LOCATION = ConfigBuilder("spark.kubernetes.mountDependencies.filesDownloadDir") - .doc("Location to download files to in the driver and executors. When using" + - " spark-submit, this directory must be empty and will be mounted as an empty directory" + - " volume on the driver and executor pods.") + .doc("Location to download files to in the driver and executors. When using " + + "spark-submit, this directory must be empty and will be mounted as an empty directory " + + "volume on the driver and executor pods.") .stringConf .createWithDefault("/var/spark-data/spark-files") + val INIT_CONTAINER_IMAGE = + ConfigBuilder("spark.kubernetes.initContainer.image") + .doc("Image for the driver and executor's init-container for downloading dependencies.") + .stringConf + .createOptional + + val INIT_CONTAINER_MOUNT_TIMEOUT = + ConfigBuilder("spark.kubernetes.mountDependencies.timeout") + .doc("Timeout before aborting the attempt to download and unpack dependencies from remote " + + "locations into the driver and executor pods.") + .timeConf(TimeUnit.SECONDS) + .createWithDefault(300) + + val INIT_CONTAINER_MAX_THREAD_POOL_SIZE = + ConfigBuilder("spark.kubernetes.mountDependencies.maxSimultaneousDownloads") + .doc("Maximum number of remote dependencies to download simultaneously in a driver or " + + "executor pod.") + .intConf + .createWithDefault(5) + + val INIT_CONTAINER_REMOTE_JARS = + ConfigBuilder("spark.kubernetes.initContainer.remoteJars") + .doc("Comma-separated list of jar URIs to download in the init-container. This is " + + "calculated from spark.jars.") + .internal() + .stringConf + .createOptional + + val INIT_CONTAINER_REMOTE_FILES = + ConfigBuilder("spark.kubernetes.initContainer.remoteFiles") + .doc("Comma-separated list of file URIs to download in the init-container. This is " + + "calculated from spark.files.") + .internal() + .stringConf + .createOptional + + val INIT_CONTAINER_CONFIG_MAP_NAME = + ConfigBuilder("spark.kubernetes.initContainer.configMapName") + .doc("Name of the config map to use in the init-container that retrieves submitted files " + + "for the executor.") + .internal() + .stringConf + .createOptional + + val INIT_CONTAINER_CONFIG_MAP_KEY_CONF = + ConfigBuilder("spark.kubernetes.initContainer.configMapKey") + .doc("Key for the entry in the init container config map for submitted files that " + + "corresponds to the properties for this init-container.") + .internal() + .stringConf + .createOptional + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" @@ -153,9 +204,11 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." + val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." + val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." 
} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 0b91145405d3..111cb2a3b75e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -69,6 +69,17 @@ private[spark] object Constants { val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY" + val ENV_MOUNTED_FILES_DIR = "SPARK_MOUNTED_FILES_DIR" + + // Bootstrapping dependencies with the init-container + val INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME = "download-jars-volume" + val INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME = "download-files-volume" + val INIT_CONTAINER_PROPERTIES_FILE_VOLUME = "spark-init-properties" + val INIT_CONTAINER_PROPERTIES_FILE_DIR = "/etc/spark-init" + val INIT_CONTAINER_PROPERTIES_FILE_NAME = "spark-init.properties" + val INIT_CONTAINER_PROPERTIES_FILE_PATH = + s"$INIT_CONTAINER_PROPERTIES_FILE_DIR/$INIT_CONTAINER_PROPERTIES_FILE_NAME" + val INIT_CONTAINER_SECRET_VOLUME_NAME = "spark-init-secret" // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala new file mode 100644 index 000000000000..dfeccf9e2bd1 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EmptyDirVolumeSource, EnvVarBuilder, PodBuilder, VolumeMount, VolumeMountBuilder} + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +/** + * Bootstraps an init-container for downloading remote dependencies. This is separated out from + * the init-container steps API because this component can be used to bootstrap init-containers + * for both the driver and executors. 
+ */ +private[spark] class InitContainerBootstrap( + initContainerImage: String, + imagePullPolicy: String, + jarsDownloadPath: String, + filesDownloadPath: String, + configMapName: String, + configMapKey: String, + sparkRole: String, + sparkConf: SparkConf) { + + /** + * Bootstraps an init-container that downloads dependencies to be used by a main container. + */ + def bootstrapInitContainer( + original: PodWithDetachedInitContainer): PodWithDetachedInitContainer = { + val sharedVolumeMounts = Seq[VolumeMount]( + new VolumeMountBuilder() + .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) + .withMountPath(jarsDownloadPath) + .build(), + new VolumeMountBuilder() + .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) + .withMountPath(filesDownloadPath) + .build()) + + val customEnvVarKeyPrefix = sparkRole match { + case SPARK_POD_DRIVER_ROLE => KUBERNETES_DRIVER_ENV_KEY + case SPARK_POD_EXECUTOR_ROLE => "spark.executorEnv." + case _ => throw new SparkException(s"$sparkRole is not a valid Spark pod role") + } + val customEnvVars = sparkConf.getAllWithPrefix(customEnvVarKeyPrefix).toSeq.map { + case (key, value) => + new EnvVarBuilder() + .withName(key) + .withValue(value) + .build() + } + + val initContainer = new ContainerBuilder(original.initContainer) + .withName("spark-init") + .withImage(initContainerImage) + .withImagePullPolicy(imagePullPolicy) + .addAllToEnv(customEnvVars.asJava) + .addNewVolumeMount() + .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) + .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) + .endVolumeMount() + .addToVolumeMounts(sharedVolumeMounts: _*) + .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) + .build() + + val podWithBasicVolumes = new PodBuilder(original.pod) + .editSpec() + .addNewVolume() + .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) + .withNewConfigMap() + .withName(configMapName) + .addNewItem() + .withKey(configMapKey) + .withPath(INIT_CONTAINER_PROPERTIES_FILE_NAME) + .endItem() + .endConfigMap() + .endVolume() + .addNewVolume() + .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) + .withEmptyDir(new EmptyDirVolumeSource()) + .endVolume() + .addNewVolume() + .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) + .withEmptyDir(new EmptyDirVolumeSource()) + .endVolume() + .endSpec() + .build() + + val mainContainer = new ContainerBuilder(original.mainContainer) + .addToVolumeMounts(sharedVolumeMounts: _*) + .addNewEnv() + .withName(ENV_MOUNTED_FILES_DIR) + .withValue(filesDownloadPath) + .endEnv() + .build() + + PodWithDetachedInitContainer( + podWithBasicVolumes, + initContainer, + mainContainer) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala similarity index 57% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index a38cf55fc3d5..37331d8bbf9b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -14,13 +14,49 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.deploy.k8s.submit +package org.apache.spark.deploy.k8s import java.io.File +import io.fabric8.kubernetes.api.model.{Container, Pod, PodBuilder} + +import org.apache.spark.SparkConf import org.apache.spark.util.Utils -private[spark] object KubernetesFileUtils { +private[spark] object KubernetesUtils { + + /** + * Extract and parse Spark configuration properties with a given name prefix and + * return the result as a Map. Keys must not have more than one value. + * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return a Map storing the configuration property keys and values + */ + def parsePrefixedKeyValuePairs( + sparkConf: SparkConf, + prefix: String): Map[String, String] = { + sparkConf.getAllWithPrefix(prefix).toMap + } + + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { + opt1.foreach { _ => require(opt2.isEmpty, errMessage) } + } + + /** + * Append the given init-container to a pod's list of init-containers. + * + * @param originalPodSpec original specification of the pod + * @param initContainer the init-container to add to the pod + * @return the pod with the init-container added to the list of InitContainers + */ + def appendInitContainer(originalPodSpec: Pod, initContainer: Container): Pod = { + new PodBuilder(originalPodSpec) + .editOrNewSpec() + .addToInitContainers(initContainer) + .endSpec() + .build() + } /** * For the given collection of file URIs, resolves them as follows: @@ -47,6 +83,16 @@ private[spark] object KubernetesFileUtils { } } + /** + * Get from a given collection of file URIs the ones that represent remote files. + */ + def getOnlyRemoteFiles(uris: Iterable[String]): Iterable[String] = { + uris.filter { uri => + val scheme = Utils.resolveURI(uri).getScheme + scheme != "file" && scheme != "local" + } + } + private def resolveFileUri( uri: String, fileDownloadPath: String, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala new file mode 100644 index 000000000000..8286546ce064 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} + +/** + * Bootstraps a driver or executor container or an init-container with needed secrets mounted. 
+ */ +private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { + + /** + * Mounts Kubernetes secrets as secret volumes into the given container in the given pod. + * + * @param pod the pod into which the secret volumes are being added. + * @param container the container into which the secret volumes are being mounted. + * @return the updated pod and container with the secrets mounted. + */ + def mountSecrets(pod: Pod, container: Container): (Pod, Container) = { + var podBuilder = new PodBuilder(pod) + secretNamesToMountPaths.keys.foreach { name => + podBuilder = podBuilder + .editOrNewSpec() + .addNewVolume() + .withName(secretVolumeName(name)) + .withNewSecret() + .withSecretName(name) + .endSecret() + .endVolume() + .endSpec() + } + + var containerBuilder = new ContainerBuilder(container) + secretNamesToMountPaths.foreach { case (name, path) => + containerBuilder = containerBuilder + .addNewVolumeMount() + .withName(secretVolumeName(name)) + .withMountPath(path) + .endVolumeMount() + } + + (podBuilder.build(), containerBuilder.build()) + } + + private def secretVolumeName(secretName: String): String = { + secretName + "-volume" + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala similarity index 53% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala index 01717479fddd..0b79f8b12e80 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala @@ -14,28 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.deploy.k8s -import org.apache.spark.SparkConf - -private[spark] object ConfigurationUtils { - - /** - * Extract and parse Spark configuration properties with a given name prefix and - * return the result as a Map. Keys must not have more than one value. - * - * @param sparkConf Spark configuration - * @param prefix the given property name prefix - * @return a Map storing the configuration property keys and values - */ - def parsePrefixedKeyValuePairs( - sparkConf: SparkConf, - prefix: String): Map[String, String] = { - sparkConf.getAllWithPrefix(prefix).toMap - } +import io.fabric8.kubernetes.api.model.{Container, Pod} - def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { - opt1.foreach { _ => require(opt2.isEmpty, errMessage) } - } -} +/** + * Represents a pod with a detached init-container (not yet added to the pod). 
+ * + * @param pod the pod + * @param initContainer the init-container in the pod + * @param mainContainer the main container in the pod + */ +private[spark] case class PodWithDetachedInitContainer( + pod: Pod, + initContainer: Container, + mainContainer: Container) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index 1e3f055e0576..c47e78cbf19e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -48,7 +48,7 @@ private[spark] object SparkKubernetesClientFactory { .map(new File(_)) .orElse(defaultServiceAccountToken) val oauthTokenValue = sparkConf.getOption(oauthTokenConf) - ConfigurationUtils.requireNandDefined( + KubernetesUtils.requireNandDefined( oauthTokenFile, oauthTokenValue, s"Cannot specify OAuth token through both a file $oauthTokenFileConf and a " + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala similarity index 53% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index 1411e6f40b46..00c9c4ee4917 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -21,25 +21,31 @@ import java.util.UUID import com.google.common.primitives.Longs import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.ConfigurationUtils import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.steps._ +import org.apache.spark.deploy.k8s.submit.steps.initcontainer.InitContainerConfigOrchestrator import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util.SystemClock +import org.apache.spark.util.Utils /** - * Constructs the complete list of driver configuration steps to run to deploy the Spark driver. + * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to + * configure the Spark driver pod. The returned steps will be applied one by one in the given + * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication + * to construct and create the driver pod. It uses the InitContainerConfigOrchestrator to + * configure the driver init-container if one is needed, i.e., when there are remote dependencies + * to localize. 
*/ -private[spark] class DriverConfigurationStepsOrchestrator( - namespace: String, +private[spark] class DriverConfigOrchestrator( kubernetesAppId: String, launchTime: Long, mainAppResource: Option[MainAppResource], appName: String, mainClass: String, appArgs: Array[String], - submissionSparkConf: SparkConf) { + sparkConf: SparkConf) { // The resource name prefix is derived from the Spark application name, making it easy to connect // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the @@ -49,13 +55,14 @@ private[spark] class DriverConfigurationStepsOrchestrator( s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-") } - private val imagePullPolicy = submissionSparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val jarsDownloadPath = submissionSparkConf.get(JARS_DOWNLOAD_LOCATION) - private val filesDownloadPath = submissionSparkConf.get(FILES_DOWNLOAD_LOCATION) + private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + private val initContainerConfigMapName = s"$kubernetesResourceNamePrefix-init-config" + private val jarsDownloadPath = sparkConf.get(JARS_DOWNLOAD_LOCATION) + private val filesDownloadPath = sparkConf.get(FILES_DOWNLOAD_LOCATION) - def getAllConfigurationSteps(): Seq[DriverConfigurationStep] = { - val driverCustomLabels = ConfigurationUtils.parsePrefixedKeyValuePairs( - submissionSparkConf, + def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { + val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + @@ -64,11 +71,15 @@ private[spark] class DriverConfigurationStepsOrchestrator( s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + "operations.") + val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_DRIVER_SECRETS_PREFIX) + val allDriverLabels = driverCustomLabels ++ Map( SPARK_APP_ID_LABEL -> kubernetesAppId, SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) - val initialSubmissionStep = new BaseDriverConfigurationStep( + val initialSubmissionStep = new BasicDriverConfigurationStep( kubernetesAppId, kubernetesResourceNamePrefix, allDriverLabels, @@ -76,16 +87,16 @@ private[spark] class DriverConfigurationStepsOrchestrator( appName, mainClass, appArgs, - submissionSparkConf) + sparkConf) - val driverAddressStep = new DriverServiceBootstrapStep( + val serviceBootstrapStep = new DriverServiceBootstrapStep( kubernetesResourceNamePrefix, allDriverLabels, - submissionSparkConf, + sparkConf, new SystemClock) val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, kubernetesResourceNamePrefix) + sparkConf, kubernetesResourceNamePrefix) val additionalMainAppJar = if (mainAppResource.nonEmpty) { val mayBeResource = mainAppResource.get match { @@ -98,28 +109,62 @@ private[spark] class DriverConfigurationStepsOrchestrator( None } - val sparkJars = submissionSparkConf.getOption("spark.jars") + val sparkJars = sparkConf.getOption("spark.jars") .map(_.split(",")) .getOrElse(Array.empty[String]) ++ additionalMainAppJar.toSeq - val sparkFiles = submissionSparkConf.getOption("spark.files") + val sparkFiles = sparkConf.getOption("spark.files") .map(_.split(",")) .getOrElse(Array.empty[String]) - val maybeDependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { - Some(new 
DependencyResolutionStep( + val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { + Seq(new DependencyResolutionStep( sparkJars, sparkFiles, jarsDownloadPath, filesDownloadPath)) } else { - None + Nil + } + + val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { + val orchestrator = new InitContainerConfigOrchestrator( + sparkJars, + sparkFiles, + jarsDownloadPath, + filesDownloadPath, + imagePullPolicy, + initContainerConfigMapName, + INIT_CONTAINER_PROPERTIES_FILE_NAME, + sparkConf) + val bootstrapStep = new DriverInitContainerBootstrapStep( + orchestrator.getAllConfigurationSteps, + initContainerConfigMapName, + INIT_CONTAINER_PROPERTIES_FILE_NAME) + + Seq(bootstrapStep) + } else { + Nil + } + + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil } Seq( initialSubmissionStep, - driverAddressStep, + serviceBootstrapStep, kubernetesCredentialsStep) ++ - maybeDependencyResolutionStep.toSeq + dependencyResolutionStep ++ + initContainerBootstrapStep ++ + mountSecretsStep + } + + private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { + files.exists { uri => + Utils.resolveURI(uri).getScheme != "local" + } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 240a1144577b..5884348cb3e4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -80,22 +80,22 @@ private[spark] object ClientArguments { * spark.kubernetes.submission.waitAppCompletion is true. 
* * @param submissionSteps steps that collectively configure the driver - * @param submissionSparkConf the submission client Spark configuration + * @param sparkConf the submission client Spark configuration * @param kubernetesClient the client to talk to the Kubernetes API server * @param waitForAppCompletion a flag indicating whether the client should wait for the application * to complete * @param appName the application name - * @param loggingPodStatusWatcher a watcher that monitors and logs the application status + * @param watcher a watcher that monitors and logs the application status */ private[spark] class Client( submissionSteps: Seq[DriverConfigurationStep], - submissionSparkConf: SparkConf, + sparkConf: SparkConf, kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, - loggingPodStatusWatcher: LoggingPodStatusWatcher) extends Logging { + watcher: LoggingPodStatusWatcher) extends Logging { - private val driverJavaOptions = submissionSparkConf.get( + private val driverJavaOptions = sparkConf.get( org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) /** @@ -104,7 +104,7 @@ private[spark] class Client( * will be used to build the Driver Container, Driver Pod, and Kubernetes Resources */ def run(): Unit = { - var currentDriverSpec = KubernetesDriverSpec.initialSpec(submissionSparkConf) + var currentDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf) // submissionSteps contain steps necessary to take, to resolve varying // client arguments that are passed in, created by orchestrator for (nextStep <- submissionSteps) { @@ -141,7 +141,7 @@ private[spark] class Client( kubernetesClient .pods() .withName(resolvedDriverPod.getMetadata.getName) - .watch(loggingPodStatusWatcher)) { _ => + .watch(watcher)) { _ => val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { if (currentDriverSpec.otherKubernetesResources.nonEmpty) { @@ -157,7 +157,7 @@ private[spark] class Client( if (waitForAppCompletion) { logInfo(s"Waiting for application $appName to finish...") - loggingPodStatusWatcher.awaitCompletion() + watcher.awaitCompletion() logInfo(s"Application $appName finished.") } else { logInfo(s"Deployed Spark application $appName into Kubernetes.") @@ -207,11 +207,9 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val master = sparkConf.get("spark.master").substring("k8s://".length) val loggingInterval = if (waitForAppCompletion) Some(sparkConf.get(REPORT_INTERVAL)) else None - val loggingPodStatusWatcher = new LoggingPodStatusWatcherImpl( - kubernetesAppId, loggingInterval) + val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) - val configurationStepsOrchestrator = new DriverConfigurationStepsOrchestrator( - namespace, + val orchestrator = new DriverConfigOrchestrator( kubernetesAppId, launchTime, clientArguments.mainAppResource, @@ -228,12 +226,12 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None, None)) { kubernetesClient => val client = new Client( - configurationStepsOrchestrator.getAllConfigurationSteps(), + orchestrator.getAllConfigurationSteps, sparkConf, kubernetesClient, waitForAppCompletion, appName, - loggingPodStatusWatcher) + watcher) client.run() } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala similarity index 70% 
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index c335fcce4036..b7a69a7dfd47 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -22,49 +22,46 @@ import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarS import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.ConfigurationUtils import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesUtils import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} /** - * Represents the initial setup required for the driver. + * Performs basic configuration for the driver pod. */ -private[spark] class BaseDriverConfigurationStep( +private[spark] class BasicDriverConfigurationStep( kubernetesAppId: String, - kubernetesResourceNamePrefix: String, + resourceNamePrefix: String, driverLabels: Map[String, String], imagePullPolicy: String, appName: String, mainClass: String, appArgs: Array[String], - submissionSparkConf: SparkConf) extends DriverConfigurationStep { + sparkConf: SparkConf) extends DriverConfigurationStep { - private val kubernetesDriverPodName = submissionSparkConf.get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(s"$kubernetesResourceNamePrefix-driver") + private val driverPodName = sparkConf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(s"$resourceNamePrefix-driver") - private val driverExtraClasspath = submissionSparkConf.get( - DRIVER_CLASS_PATH) + private val driverExtraClasspath = sparkConf.get(DRIVER_CLASS_PATH) - private val driverContainerImage = submissionSparkConf + private val driverContainerImage = sparkConf .get(DRIVER_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the driver container image")) // CPU settings - private val driverCpuCores = submissionSparkConf.getOption("spark.driver.cores").getOrElse("1") - private val driverLimitCores = submissionSparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) + private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") + private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) // Memory settings - private val driverMemoryMiB = submissionSparkConf.get( - DRIVER_MEMORY) - private val driverMemoryString = submissionSparkConf.get( - DRIVER_MEMORY.key, - DRIVER_MEMORY.defaultValueString) - private val memoryOverheadMiB = submissionSparkConf + private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) + private val driverMemoryString = sparkConf.get( + DRIVER_MEMORY.key, DRIVER_MEMORY.defaultValueString) + private val memoryOverheadMiB = sparkConf .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, - MEMORY_OVERHEAD_MIN_MIB)) - private val driverContainerMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB override def 
configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => @@ -74,15 +71,13 @@ private[spark] class BaseDriverConfigurationStep( .build() } - val driverCustomAnnotations = ConfigurationUtils - .parsePrefixedKeyValuePairs( - submissionSparkConf, - KUBERNETES_DRIVER_ANNOTATION_PREFIX) + val driverCustomAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) require(!driverCustomAnnotations.contains(SPARK_APP_NAME_ANNOTATION), s"Annotation with key $SPARK_APP_NAME_ANNOTATION is not allowed as it is reserved for" + " Spark bookkeeping operations.") - val driverCustomEnvs = submissionSparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq + val driverCustomEnvs = sparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq .map { env => new EnvVarBuilder() .withName(env._1) @@ -90,10 +85,10 @@ private[spark] class BaseDriverConfigurationStep( .build() } - val allDriverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) + val driverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) - val nodeSelector = ConfigurationUtils.parsePrefixedKeyValuePairs( - submissionSparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) + val nodeSelector = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) val driverCpuQuantity = new QuantityBuilder(false) .withAmount(driverCpuCores) @@ -102,7 +97,7 @@ private[spark] class BaseDriverConfigurationStep( .withAmount(s"${driverMemoryMiB}Mi") .build() val driverMemoryLimitQuantity = new QuantityBuilder(false) - .withAmount(s"${driverContainerMemoryWithOverheadMiB}Mi") + .withAmount(s"${driverMemoryWithOverheadMiB}Mi") .build() val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) @@ -142,9 +137,9 @@ private[spark] class BaseDriverConfigurationStep( val baseDriverPod = new PodBuilder(driverSpec.driverPod) .editOrNewMetadata() - .withName(kubernetesDriverPodName) + .withName(driverPodName) .addToLabels(driverLabels.asJava) - .addToAnnotations(allDriverAnnotations.asJava) + .addToAnnotations(driverAnnotations.asJava) .endMetadata() .withNewSpec() .withRestartPolicy("Never") @@ -153,9 +148,9 @@ private[spark] class BaseDriverConfigurationStep( .build() val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .setIfMissing(KUBERNETES_DRIVER_POD_NAME, kubernetesDriverPodName) + .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set("spark.app.id", kubernetesAppId) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, kubernetesResourceNamePrefix) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) driverSpec.copy( driverPod = baseDriverPod, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala index 44e0ecffc0e9..d4b83235b4e3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala @@ -21,7 +21,8 @@ import java.io.File import io.fabric8.kubernetes.api.model.ContainerBuilder import org.apache.spark.deploy.k8s.Constants._ -import 
org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, KubernetesFileUtils} +import org.apache.spark.deploy.k8s.KubernetesUtils +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec /** * Step that configures the classpath, spark.jars, and spark.files for the driver given that the @@ -31,21 +32,22 @@ private[spark] class DependencyResolutionStep( sparkJars: Seq[String], sparkFiles: Seq[String], jarsDownloadPath: String, - localFilesDownloadPath: String) extends DriverConfigurationStep { + filesDownloadPath: String) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesFileUtils.resolveFileUris(sparkJars, jarsDownloadPath) - val resolvedSparkFiles = KubernetesFileUtils.resolveFileUris( - sparkFiles, localFilesDownloadPath) - val sparkConfResolvedSparkDependencies = driverSpec.driverSparkConf.clone() + val resolvedSparkJars = KubernetesUtils.resolveFileUris(sparkJars, jarsDownloadPath) + val resolvedSparkFiles = KubernetesUtils.resolveFileUris(sparkFiles, filesDownloadPath) + + val sparkConf = driverSpec.driverSparkConf.clone() if (resolvedSparkJars.nonEmpty) { - sparkConfResolvedSparkDependencies.set("spark.jars", resolvedSparkJars.mkString(",")) + sparkConf.set("spark.jars", resolvedSparkJars.mkString(",")) } if (resolvedSparkFiles.nonEmpty) { - sparkConfResolvedSparkDependencies.set("spark.files", resolvedSparkFiles.mkString(",")) + sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) } - val resolvedClasspath = KubernetesFileUtils.resolveFilePaths(sparkJars, jarsDownloadPath) - val driverContainerWithResolvedClasspath = if (resolvedClasspath.nonEmpty) { + + val resolvedClasspath = KubernetesUtils.resolveFilePaths(sparkJars, jarsDownloadPath) + val resolvedDriverContainer = if (resolvedClasspath.nonEmpty) { new ContainerBuilder(driverSpec.driverContainer) .addNewEnv() .withName(ENV_MOUNTED_CLASSPATH) @@ -55,8 +57,9 @@ private[spark] class DependencyResolutionStep( } else { driverSpec.driverContainer } + driverSpec.copy( - driverContainer = driverContainerWithResolvedClasspath, - driverSparkConf = sparkConfResolvedSparkDependencies) + driverContainer = resolvedDriverContainer, + driverSparkConf = sparkConf) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala index c99c0436cf25..17614e040e58 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.k8s.submit.steps import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec /** - * Represents a step in preparing the Kubernetes driver. + * Represents a step in configuring the Spark driver pod. 
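A hypothetical implementation of this trait, shown only to illustrate the contract: the label key and value are invented, and the class is assumed to live in the same package as the trait.

    import io.fabric8.kubernetes.api.model.PodBuilder

    import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec

    private[spark] class ExampleLabelingStep extends DriverConfigurationStep {
      override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = {
        // Decorate the driver pod with one extra (made-up) label and keep everything else as is.
        val labeledPod = new PodBuilder(driverSpec.driverPod)
          .editOrNewMetadata()
            .addToLabels("example-key", "example-value")
            .endMetadata()
          .build()
        driverSpec.copy(driverPod = labeledPod)
      }
    }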
*/ private[spark] trait DriverConfigurationStep { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala new file mode 100644 index 000000000000..9fb3dafdda54 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import java.io.StringWriter +import java.util.Properties + +import io.fabric8.kubernetes.api.model.{ConfigMap, ConfigMapBuilder, ContainerBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.KubernetesUtils +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} + +/** + * Configures the driver init-container that localizes remote dependencies into the driver pod. + * It applies the given InitContainerConfigurationSteps in the given order to produce a final + * InitContainerSpec that is then used to configure the driver pod with the init-container attached. + * It also builds a ConfigMap that will be mounted into the init-container. The ConfigMap carries + * configuration properties for the init-container. 
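The ConfigMap payload is simply a java.util.Properties rendering of those properties. A self-contained sketch of the round trip (written on the submission side, re-read inside the init-container); the property key is made up:

    import java.io.{StringReader, StringWriter}
    import java.util.Properties

    val props = new Properties()
    props.setProperty("example.init.property", "example-value") // hypothetical entry

    val writer = new StringWriter()
    props.store(writer, "init-container properties")
    val configMapValue = writer.toString // stored in the ConfigMap under the config map key

    val reloaded = new Properties()
    reloaded.load(new StringReader(configMapValue))
    assert(reloaded.getProperty("example.init.property") == "example-value")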
+ */ +private[spark] class DriverInitContainerBootstrapStep( + steps: Seq[InitContainerConfigurationStep], + configMapName: String, + configMapKey: String) + extends DriverConfigurationStep { + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + var initContainerSpec = InitContainerSpec( + properties = Map.empty[String, String], + driverSparkConf = Map.empty[String, String], + initContainer = new ContainerBuilder().build(), + driverContainer = driverSpec.driverContainer, + driverPod = driverSpec.driverPod, + dependentResources = Seq.empty[HasMetadata]) + for (nextStep <- steps) { + initContainerSpec = nextStep.configureInitContainer(initContainerSpec) + } + + val configMap = buildConfigMap( + configMapName, + configMapKey, + initContainerSpec.properties) + val resolvedDriverSparkConf = driverSpec.driverSparkConf + .clone() + .set(INIT_CONTAINER_CONFIG_MAP_NAME, configMapName) + .set(INIT_CONTAINER_CONFIG_MAP_KEY_CONF, configMapKey) + .setAll(initContainerSpec.driverSparkConf) + val resolvedDriverPod = KubernetesUtils.appendInitContainer( + initContainerSpec.driverPod, initContainerSpec.initContainer) + + driverSpec.copy( + driverPod = resolvedDriverPod, + driverContainer = initContainerSpec.driverContainer, + driverSparkConf = resolvedDriverSparkConf, + otherKubernetesResources = + driverSpec.otherKubernetesResources ++ + initContainerSpec.dependentResources ++ + Seq(configMap)) + } + + private def buildConfigMap( + configMapName: String, + configMapKey: String, + config: Map[String, String]): ConfigMap = { + val properties = new Properties() + config.foreach { entry => + properties.setProperty(entry._1, entry._2) + } + val propertiesWriter = new StringWriter() + properties.store(propertiesWriter, + s"Java properties built from Kubernetes config map with name: $configMapName " + + s"and config map key: $configMapKey") + new ConfigMapBuilder() + .withNewMetadata() + .withName(configMapName) + .endMetadata() + .addToData(configMapKey, propertiesWriter.toString) + .build() + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala new file mode 100644 index 000000000000..f872e0f4b65d --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.submit.steps + +import org.apache.spark.deploy.k8s.MountSecretsBootstrap +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec + +/** + * A driver configuration step for mounting user-specified secrets onto user-specified paths. + * + * @param bootstrap a utility actually handling mounting of the secrets. + */ +private[spark] class DriverMountSecretsStep( + bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + val (pod, container) = bootstrap.mountSecrets( + driverSpec.driverPod, driverSpec.driverContainer) + driverSpec.copy( + driverPod = pod, + driverContainer = container + ) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala index 696d11f15ed9..eb594e4f16ec 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala @@ -32,21 +32,22 @@ import org.apache.spark.util.Clock * ports should correspond to the ports that the executor will reach the pod at for RPC. */ private[spark] class DriverServiceBootstrapStep( - kubernetesResourceNamePrefix: String, + resourceNamePrefix: String, driverLabels: Map[String, String], - submissionSparkConf: SparkConf, + sparkConf: SparkConf, clock: Clock) extends DriverConfigurationStep with Logging { + import DriverServiceBootstrapStep._ override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - require(submissionSparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, + require(sparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + "address is managed and set to the driver pod's IP address.") - require(submissionSparkConf.getOption(DRIVER_HOST_KEY).isEmpty, + require(sparkConf.getOption(DRIVER_HOST_KEY).isEmpty, s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + "managed via a Kubernetes service.") - val preferredServiceName = s"$kubernetesResourceNamePrefix$DRIVER_SVC_POSTFIX" + val preferredServiceName = s"$resourceNamePrefix$DRIVER_SVC_POSTFIX" val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { preferredServiceName } else { @@ -58,8 +59,8 @@ private[spark] class DriverServiceBootstrapStep( shorterServiceName } - val driverPort = submissionSparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) - val driverBlockManagerPort = submissionSparkConf.getInt( + val driverPort = sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) + val driverBlockManagerPort = sparkConf.getInt( org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) val driverService = new ServiceBuilder() .withNewMetadata() @@ -81,7 +82,7 @@ private[spark] class DriverServiceBootstrapStep( .endSpec() .build() - val namespace = submissionSparkConf.get(KUBERNETES_NAMESPACE) + val namespace = sparkConf.get(KUBERNETES_NAMESPACE) val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc.cluster.local" val resolvedSparkConf = driverSpec.driverSparkConf.clone() 
.set(DRIVER_HOST_KEY, driverHostname) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala new file mode 100644 index 000000000000..01469853dacc --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.KubernetesUtils + +/** + * Performs basic configuration for the driver init-container with most of the work delegated to + * the given InitContainerBootstrap. 
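For a submission with one container-local jar and one remote jar, the base properties this step contributes reduce to roughly the following sketch; the download paths are assumed example values, and the keys are referenced through the config entries rather than spelled out:

    import org.apache.spark.deploy.k8s.Config._
    import org.apache.spark.deploy.k8s.KubernetesUtils

    val sparkJars = Seq(
      "local:///opt/spark/examples/jars/app.jar", // already in the image, not downloaded
      "https://example.com/deps/dep.jar")         // remote, localized by the init-container
    val remoteJars = KubernetesUtils.getOnlyRemoteFiles(sparkJars)

    val baseInitContainerConfig = Map(
      JARS_DOWNLOAD_LOCATION.key -> "/var/spark-data/spark-jars",   // assumed path
      FILES_DOWNLOAD_LOCATION.key -> "/var/spark-data/spark-files", // assumed path
      INIT_CONTAINER_REMOTE_JARS.key -> remoteJars.mkString(","))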
+ */ +private[spark] class BasicInitContainerConfigurationStep( + sparkJars: Seq[String], + sparkFiles: Seq[String], + jarsDownloadPath: String, + filesDownloadPath: String, + bootstrap: InitContainerBootstrap) + extends InitContainerConfigurationStep { + + override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec = { + val remoteJarsToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkJars) + val remoteFilesToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkFiles) + val remoteJarsConf = if (remoteJarsToDownload.nonEmpty) { + Map(INIT_CONTAINER_REMOTE_JARS.key -> remoteJarsToDownload.mkString(",")) + } else { + Map() + } + val remoteFilesConf = if (remoteFilesToDownload.nonEmpty) { + Map(INIT_CONTAINER_REMOTE_FILES.key -> remoteFilesToDownload.mkString(",")) + } else { + Map() + } + + val baseInitContainerConfig = Map( + JARS_DOWNLOAD_LOCATION.key -> jarsDownloadPath, + FILES_DOWNLOAD_LOCATION.key -> filesDownloadPath) ++ + remoteJarsConf ++ + remoteFilesConf + + val bootstrapped = bootstrap.bootstrapInitContainer( + PodWithDetachedInitContainer( + spec.driverPod, + spec.initContainer, + spec.driverContainer)) + + spec.copy( + initContainer = bootstrapped.initContainer, + driverContainer = bootstrapped.mainContainer, + driverPod = bootstrapped.pod, + properties = spec.properties ++ baseInitContainerConfig) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala new file mode 100644 index 000000000000..f2c29c7ce107 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +/** + * Figures out and returns the complete ordered list of InitContainerConfigurationSteps required to + * configure the driver init-container. The returned steps will be applied in the given order to + * produce a final InitContainerSpec that is used to construct the driver init-container in + * DriverInitContainerBootstrapStep. This class is only used when an init-container is needed, i.e., + * when there are remote application dependencies to localize. 
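Whether an init-container is needed at all is decided by the URI scheme of the submitted dependencies: everything other than the "local" scheme counts as remote. A small sketch of that check, mirroring existNonContainerLocalFiles in DriverConfigOrchestrator (the URIs are examples):

    import org.apache.spark.util.Utils

    val dependencies = Seq(
      "local:///opt/spark/jars/app.jar",   // baked into the image, no download needed
      "hdfs://namenode:9000/jars/lib.jar", // remote
      "https://example.com/deps/dep.jar")  // remote

    val needsInitContainer = dependencies.exists { uri =>
      Utils.resolveURI(uri).getScheme != "local"
    }
    // needsInitContainer == true, so the orchestrator above would be instantiated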
+ */ +private[spark] class InitContainerConfigOrchestrator( + sparkJars: Seq[String], + sparkFiles: Seq[String], + jarsDownloadPath: String, + filesDownloadPath: String, + imagePullPolicy: String, + configMapName: String, + configMapKey: String, + sparkConf: SparkConf) { + + private val initContainerImage = sparkConf + .get(INIT_CONTAINER_IMAGE) + .getOrElse(throw new SparkException( + "Must specify the init-container image when there are remote dependencies")) + + def getAllConfigurationSteps: Seq[InitContainerConfigurationStep] = { + val initContainerBootstrap = new InitContainerBootstrap( + initContainerImage, + imagePullPolicy, + jarsDownloadPath, + filesDownloadPath, + configMapName, + configMapKey, + SPARK_POD_DRIVER_ROLE, + sparkConf) + val baseStep = new BasicInitContainerConfigurationStep( + sparkJars, + sparkFiles, + jarsDownloadPath, + filesDownloadPath, + initContainerBootstrap) + + val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_DRIVER_SECRETS_PREFIX) + // Mount user-specified driver secrets also into the driver's init-container. The + // init-container may need credentials in the secrets to be able to download remote + // dependencies. The driver's main container and its init-container share the secrets + // because the init-container is sort of an implementation details and this sharing + // avoids introducing a dedicated configuration property just for the init-container. + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new InitContainerMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil + } + + Seq(baseStep) ++ mountSecretsStep + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala new file mode 100644 index 000000000000..0372ad527095 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +/** + * Represents a step in configuring the driver init-container. 
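A hypothetical implementation of the trait declared just below, given only to illustrate the contract; the property it contributes is invented and the class is assumed to sit in the same package as the trait:

    private[spark] class ExamplePropertyStep extends InitContainerConfigurationStep {
      override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec = {
        // Add one made-up entry to the properties that end up in the init-container's config map.
        spec.copy(properties = spec.properties ++ Map("example.download.flag" -> "true"))
      }
    }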
+ */ +private[spark] trait InitContainerConfigurationStep { + + def configureInitContainer(spec: InitContainerSpec): InitContainerSpec +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala new file mode 100644 index 000000000000..c0e7bb20cce8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.deploy.k8s.MountSecretsBootstrap + +/** + * An init-container configuration step for mounting user-specified secrets onto user-specified + * paths. + * + * @param bootstrap a utility actually handling mounting of the secrets + */ +private[spark] class InitContainerMountSecretsStep( + bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { + + override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { + val (driverPod, initContainer) = bootstrap.mountSecrets( + spec.driverPod, + spec.initContainer) + spec.copy( + driverPod = driverPod, + initContainer = initContainer + ) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala new file mode 100644 index 000000000000..b52c343f0c0e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import io.fabric8.kubernetes.api.model.{Container, HasMetadata, Pod} + +/** + * Represents a specification of the init-container for the driver pod. + * + * @param properties properties that should be set on the init-container + * @param driverSparkConf Spark configuration properties that will be carried back to the driver + * @param initContainer the init-container object + * @param driverContainer the driver container object + * @param driverPod the driver pod object + * @param dependentResources resources the init-container depends on to work + */ +private[spark] case class InitContainerSpec( + properties: Map[String, String], + driverSparkConf: Map[String, String], + initContainer: Container, + driverContainer: Container, + driverPod: Pod, + dependentResources: Seq[HasMetadata]) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala new file mode 100644 index 000000000000..4a4b628aedbb --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.rest.k8s + +import java.io.File +import java.util.concurrent.TimeUnit + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.{SecurityManager => SparkSecurityManager, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Process that fetches files from a resource staging server and/or arbitrary remote locations. + * + * The init-container can handle fetching files from any of those sources, but not all of the + * sources need to be specified. This allows for composing multiple instances of this container + * with different configurations for different download sources, or using the same container to + * download everything at once. 
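The main entry point further below builds everything from a properties file; driven directly, the same flow looks roughly like this sketch. The download directory, the remote jar URI, and the assumption that the snippet lives in the same package as the package-private FileFetcher are all illustrative:

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.deploy.k8s.Config._

    val conf = new SparkConf(false)
      .set(JARS_DOWNLOAD_LOCATION.key, "/tmp/spark-jars")                 // must be an existing directory
      .set(INIT_CONTAINER_REMOTE_JARS.key, "https://example.com/dep.jar") // what to fetch
    val fetcher = new FileFetcher(conf, new SecurityManager(conf))
    new SparkPodInitContainer(conf, fetcher).run()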
+ */ +private[spark] class SparkPodInitContainer( + sparkConf: SparkConf, + fileFetcher: FileFetcher) extends Logging { + + private val maxThreadPoolSize = sparkConf.get(INIT_CONTAINER_MAX_THREAD_POOL_SIZE) + private implicit val downloadExecutor = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("download-executor", maxThreadPoolSize)) + + private val jarsDownloadDir = new File(sparkConf.get(JARS_DOWNLOAD_LOCATION)) + private val filesDownloadDir = new File(sparkConf.get(FILES_DOWNLOAD_LOCATION)) + + private val remoteJars = sparkConf.get(INIT_CONTAINER_REMOTE_JARS) + private val remoteFiles = sparkConf.get(INIT_CONTAINER_REMOTE_FILES) + + private val downloadTimeoutMinutes = sparkConf.get(INIT_CONTAINER_MOUNT_TIMEOUT) + + def run(): Unit = { + logInfo(s"Downloading remote jars: $remoteJars") + downloadFiles( + remoteJars, + jarsDownloadDir, + s"Remote jars download directory specified at $jarsDownloadDir does not exist " + + "or is not a directory.") + + logInfo(s"Downloading remote files: $remoteFiles") + downloadFiles( + remoteFiles, + filesDownloadDir, + s"Remote files download directory specified at $filesDownloadDir does not exist " + + "or is not a directory.") + + downloadExecutor.shutdown() + downloadExecutor.awaitTermination(downloadTimeoutMinutes, TimeUnit.MINUTES) + } + + private def downloadFiles( + filesCommaSeparated: Option[String], + downloadDir: File, + errMessage: String): Unit = { + filesCommaSeparated.foreach { files => + require(downloadDir.isDirectory, errMessage) + Utils.stringToSeq(files).foreach { file => + Future[Unit] { + fileFetcher.fetchFile(file, downloadDir) + } + } + } + } +} + +private class FileFetcher(sparkConf: SparkConf, securityManager: SparkSecurityManager) { + + def fetchFile(uri: String, targetDir: File): Unit = { + Utils.fetchFile( + url = uri, + targetDir = targetDir, + conf = sparkConf, + securityMgr = securityManager, + hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf), + timestamp = System.currentTimeMillis(), + useCache = false) + } +} + +object SparkPodInitContainer extends Logging { + + def main(args: Array[String]): Unit = { + logInfo("Starting init-container to download Spark application dependencies.") + val sparkConf = new SparkConf(true) + if (args.nonEmpty) { + Utils.loadDefaultSparkProperties(sparkConf, args(0)) + } + + val securityManager = new SparkSecurityManager(sparkConf) + val fileFetcher = new FileFetcher(sparkConf, securityManager) + new SparkPodInitContainer(sparkConf, fileFetcher).run() + logInfo("Finished downloading application dependencies.") + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 70226157dd68..ba5d891f4c77 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -21,35 +21,35 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, PodWithDetachedInitContainer} import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.ConfigurationUtils import org.apache.spark.deploy.k8s.Constants._ 
import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} import org.apache.spark.util.Utils /** - * A factory class for configuring and creating executor pods. + * A factory class for bootstrapping and creating executor pods with the given bootstrapping + * components. + * + * @param sparkConf Spark configuration + * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto + * user-specified paths into the executor container + * @param initContainerBootstrap an optional component for bootstrapping the executor init-container + * if one is needed, i.e., when there are remote dependencies to + * localize + * @param initContainerMountSecretsBootstrap an optional component for mounting user-specified + * secrets onto user-specified paths into the executor + * init-container */ -private[spark] trait ExecutorPodFactory { - - /** - * Configure and construct an executor pod with the given parameters. - */ - def createExecutorPod( - executorId: String, - applicationId: String, - driverUrl: String, - executorEnvs: Seq[(String, String)], - driverPod: Pod, - nodeToLocalTaskCount: Map[String, Int]): Pod -} - -private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) - extends ExecutorPodFactory { +private[spark] class ExecutorPodFactory( + sparkConf: SparkConf, + mountSecretsBootstrap: Option[MountSecretsBootstrap], + initContainerBootstrap: Option[InitContainerBootstrap], + initContainerMountSecretsBootstrap: Option[MountSecretsBootstrap]) { private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) - private val executorLabels = ConfigurationUtils.parsePrefixedKeyValuePairs( + private val executorLabels = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) require( @@ -64,11 +64,11 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") private val executorAnnotations = - ConfigurationUtils.parsePrefixedKeyValuePairs( + KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) private val nodeSelector = - ConfigurationUtils.parsePrefixedKeyValuePairs( + KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) @@ -94,7 +94,10 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - override def createExecutorPod( + /** + * Configure and construct an executor pod with the given parameters. 
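Pulled out of the cluster-manager wiring further down in this patch, constructing the factory and requesting a single executor pod looks roughly like the following sketch; the driver URL, driver pod, and executor id are placeholders, and sparkConf is assumed to already carry the required executor container image setting:

    import io.fabric8.kubernetes.api.model.PodBuilder

    val factory = new ExecutorPodFactory(
      sparkConf,
      mountSecretsBootstrap = None,               // no secrets to mount
      initContainerBootstrap = None,              // no remote dependencies to localize
      initContainerMountSecretsBootstrap = None)

    val executorPod = factory.createExecutorPod(
      executorId = "1",
      applicationId = "spark-app-id",
      driverUrl = "spark://CoarseGrainedScheduler@spark-driver-svc:7078", // placeholder
      executorEnvs = Seq.empty,
      driverPod = new PodBuilder().withNewMetadata().withName("spark-driver").endMetadata().build(),
      nodeToLocalTaskCount = Map.empty)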
+ */ + def createExecutorPod( executorId: String, applicationId: String, driverUrl: String, @@ -198,7 +201,7 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) .endSpec() .build() - val containerWithExecutorLimitCores = executorLimitCores.map { limitCores => + val containerWithLimitCores = executorLimitCores.map { limitCores => val executorCpuLimitQuantity = new QuantityBuilder(false) .withAmount(limitCores) .build() @@ -209,9 +212,33 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) .build() }.getOrElse(executorContainer) - new PodBuilder(executorPod) + val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = + mountSecretsBootstrap.map { bootstrap => + bootstrap.mountSecrets(executorPod, containerWithLimitCores) + }.getOrElse((executorPod, containerWithLimitCores)) + + val (bootstrappedPod, bootstrappedContainer) = + initContainerBootstrap.map { bootstrap => + val podWithInitContainer = bootstrap.bootstrapInitContainer( + PodWithDetachedInitContainer( + maybeSecretsMountedPod, + new ContainerBuilder().build(), + maybeSecretsMountedContainer)) + + val (pod, mayBeSecretsMountedInitContainer) = + initContainerMountSecretsBootstrap.map { bootstrap => + bootstrap.mountSecrets(podWithInitContainer.pod, podWithInitContainer.initContainer) + }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) + + val bootstrappedPod = KubernetesUtils.appendInitContainer( + pod, mayBeSecretsMountedInitContainer) + + (bootstrappedPod, podWithInitContainer.mainContainer) + }.getOrElse((maybeSecretsMountedPod, maybeSecretsMountedContainer)) + + new PodBuilder(bootstrappedPod) .editSpec() - .addToContainers(containerWithExecutorLimitCores) + .addToContainers(bootstrappedContainer) .endSpec() .build() } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index b8bb152d1791..a942db6ae02d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,9 +21,9 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.util.ThreadUtils @@ -45,6 +45,59 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { val sparkConf = sc.getConf + val initContainerConfigMap = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_NAME) + val initContainerConfigMapKey = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_KEY_CONF) + + if (initContainerConfigMap.isEmpty) { + logWarning("The executor's init-container config map is not specified. 
Executors will " +
+        "therefore not attempt to fetch remote or submitted dependencies.")
+    }
+
+    if (initContainerConfigMapKey.isEmpty) {
+      logWarning("The executor's init-container config map key is not specified. Executors will " +
+        "therefore not attempt to fetch remote or submitted dependencies.")
+    }
+
+    // Only set up the bootstrap if they've provided both the config map key and the config map
+    // name. The config map might not be provided if init-containers aren't being used to
+    // bootstrap dependencies.
+    val initContainerBootstrap = for {
+      configMap <- initContainerConfigMap
+      configMapKey <- initContainerConfigMapKey
+    } yield {
+      val initContainerImage = sparkConf
+        .get(INIT_CONTAINER_IMAGE)
+        .getOrElse(throw new SparkException(
+          "Must specify the init-container image when there are remote dependencies"))
+      new InitContainerBootstrap(
+        initContainerImage,
+        sparkConf.get(CONTAINER_IMAGE_PULL_POLICY),
+        sparkConf.get(JARS_DOWNLOAD_LOCATION),
+        sparkConf.get(FILES_DOWNLOAD_LOCATION),
+        configMap,
+        configMapKey,
+        SPARK_POD_EXECUTOR_ROLE,
+        sparkConf)
+    }
+
+    val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs(
+      sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX)
+    val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) {
+      Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths))
+    } else {
+      None
+    }
+    // Mount user-specified executor secrets also into the executor's init-container. The
+    // init-container may need credentials in the secrets to be able to download remote
+    // dependencies. The executor's main container and its init-container share the secrets
+    // because the init-container is sort of an implementation detail and this sharing
+    // avoids introducing a dedicated configuration property just for the init-container.
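How the user-facing configuration feeds these values can be sketched as follows; the secret name and mount path are examples chosen to mirror the driver-side test values elsewhere in this patch:

    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.k8s.Config._
    import org.apache.spark.deploy.k8s.KubernetesUtils

    val conf = new SparkConf(false)
      .set(s"${KUBERNETES_EXECUTOR_SECRETS_PREFIX}foo", "/etc/secrets/executor")

    val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs(
      conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX)
    // secretNamesToMountPaths == Map("foo" -> "/etc/secrets/executor")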
+ val initContainerMountSecretsBootstrap = if (initContainerBootstrap.nonEmpty && + executorSecretNamesToMountPaths.nonEmpty) { + Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) + } else { + None + } val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, @@ -54,7 +107,12 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactoryImpl(sparkConf) + val executorPodFactory = new ExecutorPodFactory( + sparkConf, + mountSecretBootstrap, + initContainerBootstrap, + initContainerMountSecretsBootstrap) + val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala similarity index 51% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index 98f9f27da5cd..f193b1f4d366 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -17,25 +17,27 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config.DRIVER_CONTAINER_IMAGE +import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.submit.steps._ -class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { +class DriverConfigOrchestratorSuite extends SparkFunSuite { - private val NAMESPACE = "default" private val DRIVER_IMAGE = "driver-image" + private val IC_IMAGE = "init-container-image" private val APP_ID = "spark-app-id" private val LAUNCH_TIME = 975256L private val APP_NAME = "spark" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" private val APP_ARGS = Array("arg1", "arg2") + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("Base submission steps with a main app resource.") { val sparkConf = new SparkConf(false) .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigurationStepsOrchestrator( - NAMESPACE, + val orchestrator = new DriverConfigOrchestrator( APP_ID, LAUNCH_TIME, Some(mainAppResource), @@ -45,7 +47,7 @@ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { sparkConf) validateStepTypes( orchestrator, - classOf[BaseDriverConfigurationStep], + classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], classOf[DriverKubernetesCredentialsStep], classOf[DependencyResolutionStep] @@ -55,8 +57,7 @@ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { test("Base submission steps without a main app resource.") { val sparkConf = 
new SparkConf(false) .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) - val orchestrator = new DriverConfigurationStepsOrchestrator( - NAMESPACE, + val orchestrator = new DriverConfigOrchestrator( APP_ID, LAUNCH_TIME, Option.empty, @@ -66,16 +67,62 @@ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { sparkConf) validateStepTypes( orchestrator, - classOf[BaseDriverConfigurationStep], + classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], classOf[DriverKubernetesCredentialsStep] ) } + test("Submission steps with an init-container.") { + val sparkConf = new SparkConf(false) + .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + .set(INIT_CONTAINER_IMAGE, IC_IMAGE) + .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") + val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") + val orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(mainAppResource), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + validateStepTypes( + orchestrator, + classOf[BasicDriverConfigurationStep], + classOf[DriverServiceBootstrapStep], + classOf[DriverKubernetesCredentialsStep], + classOf[DependencyResolutionStep], + classOf[DriverInitContainerBootstrapStep]) + } + + test("Submission steps with driver secrets to mount") { + val sparkConf = new SparkConf(false) + .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) + .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) + val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") + val orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(mainAppResource), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + validateStepTypes( + orchestrator, + classOf[BasicDriverConfigurationStep], + classOf[DriverServiceBootstrapStep], + classOf[DriverKubernetesCredentialsStep], + classOf[DependencyResolutionStep], + classOf[DriverMountSecretsStep]) + } + private def validateStepTypes( - orchestrator: DriverConfigurationStepsOrchestrator, + orchestrator: DriverConfigOrchestrator, types: Class[_ <: DriverConfigurationStep]*): Unit = { - val steps = orchestrator.getAllConfigurationSteps() + val steps = orchestrator.getAllConfigurationSteps assert(steps.size === types.size) assert(steps.map(_.getClass) === types) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala new file mode 100644 index 000000000000..8388c16ded26 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Container, Pod} + +private[spark] object SecretVolumeUtils { + + def podHasVolume(driverPod: Pod, volumeName: String): Boolean = { + driverPod.getSpec.getVolumes.asScala.exists(volume => volume.getName == volumeName) + } + + def containerHasVolume( + driverContainer: Container, + volumeName: String, + mountPath: String): Boolean = { + driverContainer.getVolumeMounts.asScala.exists(volumeMount => + volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala similarity index 97% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index f7c1b3142cf7..e864c6a16eeb 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -class BaseDriverConfigurationStepSuite extends SparkFunSuite { +class BasicDriverConfigurationStepSuite extends SparkFunSuite { private val APP_ID = "spark-app-id" private val RESOURCE_NAME_PREFIX = "spark" @@ -52,7 +52,7 @@ class BaseDriverConfigurationStepSuite extends SparkFunSuite { .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") - val submissionStep = new BaseDriverConfigurationStep( + val submissionStep = new BasicDriverConfigurationStep( APP_ID, RESOURCE_NAME_PREFIX, DRIVER_LABELS, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala new file mode 100644 index 000000000000..758871e2ba35 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import java.io.StringReader +import java.util.Properties + +import scala.collection.JavaConverters._ + +import com.google.common.collect.Maps +import io.fabric8.kubernetes.api.model.{ConfigMap, ContainerBuilder, HasMetadata, PodBuilder, SecretBuilder} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} +import org.apache.spark.util.Utils + +class DriverInitContainerBootstrapStepSuite extends SparkFunSuite { + + private val CONFIG_MAP_NAME = "spark-init-config-map" + private val CONFIG_MAP_KEY = "spark-init-config-map-key" + + test("The init container bootstrap step should use all of the init container steps") { + val baseDriverSpec = KubernetesDriverSpec( + driverPod = new PodBuilder().build(), + driverContainer = new ContainerBuilder().build(), + driverSparkConf = new SparkConf(false), + otherKubernetesResources = Seq.empty[HasMetadata]) + val initContainerSteps = Seq( + FirstTestInitContainerConfigurationStep, + SecondTestInitContainerConfigurationStep) + val bootstrapStep = new DriverInitContainerBootstrapStep( + initContainerSteps, + CONFIG_MAP_NAME, + CONFIG_MAP_KEY) + + val preparedDriverSpec = bootstrapStep.configureDriver(baseDriverSpec) + + assert(preparedDriverSpec.driverPod.getMetadata.getLabels.asScala === + FirstTestInitContainerConfigurationStep.additionalLabels) + val additionalDriverEnv = preparedDriverSpec.driverContainer.getEnv.asScala + assert(additionalDriverEnv.size === 1) + assert(additionalDriverEnv.head.getName === + FirstTestInitContainerConfigurationStep.additionalMainContainerEnvKey) + assert(additionalDriverEnv.head.getValue === + FirstTestInitContainerConfigurationStep.additionalMainContainerEnvValue) + + assert(preparedDriverSpec.otherKubernetesResources.size === 2) + assert(preparedDriverSpec.otherKubernetesResources.contains( + FirstTestInitContainerConfigurationStep.additionalKubernetesResource)) + assert(preparedDriverSpec.otherKubernetesResources.exists { + case configMap: ConfigMap => + val hasMatchingName = configMap.getMetadata.getName == CONFIG_MAP_NAME + val configMapData = configMap.getData.asScala + val hasCorrectNumberOfEntries = configMapData.size == 1 + val initContainerPropertiesRaw = configMapData(CONFIG_MAP_KEY) + val initContainerProperties = new Properties() + Utils.tryWithResource(new StringReader(initContainerPropertiesRaw)) { + initContainerProperties.load(_) + } + val initContainerPropertiesMap = Maps.fromProperties(initContainerProperties).asScala + val expectedInitContainerProperties = Map( + SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyKey -> + SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyValue) + val hasMatchingProperties = initContainerPropertiesMap == expectedInitContainerProperties + hasMatchingName && hasCorrectNumberOfEntries && hasMatchingProperties + + case _ => false + }) + + val 
initContainers = preparedDriverSpec.driverPod.getSpec.getInitContainers + assert(initContainers.size() === 1) + val initContainerEnv = initContainers.get(0).getEnv.asScala + assert(initContainerEnv.size === 1) + assert(initContainerEnv.head.getName === + SecondTestInitContainerConfigurationStep.additionalInitContainerEnvKey) + assert(initContainerEnv.head.getValue === + SecondTestInitContainerConfigurationStep.additionalInitContainerEnvValue) + + val expectedSparkConf = Map( + INIT_CONTAINER_CONFIG_MAP_NAME.key -> CONFIG_MAP_NAME, + INIT_CONTAINER_CONFIG_MAP_KEY_CONF.key -> CONFIG_MAP_KEY, + SecondTestInitContainerConfigurationStep.additionalDriverSparkConfKey -> + SecondTestInitContainerConfigurationStep.additionalDriverSparkConfValue) + assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) + } +} + +private object FirstTestInitContainerConfigurationStep extends InitContainerConfigurationStep { + + val additionalLabels = Map("additionalLabelkey" -> "additionalLabelValue") + val additionalMainContainerEnvKey = "TEST_ENV_MAIN_KEY" + val additionalMainContainerEnvValue = "TEST_ENV_MAIN_VALUE" + val additionalKubernetesResource = new SecretBuilder() + .withNewMetadata() + .withName("test-secret") + .endMetadata() + .addToData("secret-key", "secret-value") + .build() + + override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { + val driverPod = new PodBuilder(initContainerSpec.driverPod) + .editOrNewMetadata() + .addToLabels(additionalLabels.asJava) + .endMetadata() + .build() + val mainContainer = new ContainerBuilder(initContainerSpec.driverContainer) + .addNewEnv() + .withName(additionalMainContainerEnvKey) + .withValue(additionalMainContainerEnvValue) + .endEnv() + .build() + initContainerSpec.copy( + driverPod = driverPod, + driverContainer = mainContainer, + dependentResources = initContainerSpec.dependentResources ++ + Seq(additionalKubernetesResource)) + } +} + +private object SecondTestInitContainerConfigurationStep extends InitContainerConfigurationStep { + val additionalInitContainerEnvKey = "TEST_ENV_INIT_KEY" + val additionalInitContainerEnvValue = "TEST_ENV_INIT_VALUE" + val additionalInitContainerPropertyKey = "spark.initcontainer.testkey" + val additionalInitContainerPropertyValue = "testvalue" + val additionalDriverSparkConfKey = "spark.driver.testkey" + val additionalDriverSparkConfValue = "spark.driver.testvalue" + + override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { + val initContainer = new ContainerBuilder(initContainerSpec.initContainer) + .addNewEnv() + .withName(additionalInitContainerEnvKey) + .withValue(additionalInitContainerEnvValue) + .endEnv() + .build() + val initContainerProperties = initContainerSpec.properties ++ + Map(additionalInitContainerPropertyKey -> additionalInitContainerPropertyValue) + val driverSparkConf = initContainerSpec.driverSparkConf ++ + Map(additionalDriverSparkConfKey -> additionalDriverSparkConfValue) + initContainerSpec.copy( + initContainer = initContainer, + properties = initContainerProperties, + driverSparkConf = driverSparkConf) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala new file mode 100644 index 000000000000..9ec0cb55de5a --- /dev/null +++ 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.MountSecretsBootstrap +import org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, SecretVolumeUtils} + +class DriverMountSecretsStepSuite extends SparkFunSuite { + + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/driver" + + test("mounts all given secrets") { + val baseDriverSpec = KubernetesDriverSpec.initialSpec(new SparkConf(false)) + val secretNamesToMountPaths = Map( + SECRET_FOO -> SECRET_MOUNT_PATH, + SECRET_BAR -> SECRET_MOUNT_PATH) + + val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) + val mountSecretsStep = new DriverMountSecretsStep(mountSecretsBootstrap) + val configuredDriverSpec = mountSecretsStep.configureDriver(baseDriverSpec) + val driverPodWithSecretsMounted = configuredDriverSpec.driverPod + val driverContainerWithSecretsMounted = configuredDriverSpec.driverContainer + + Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => + assert(SecretVolumeUtils.podHasVolume(driverPodWithSecretsMounted, volumeName)) + } + Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => + assert(SecretVolumeUtils.containerHasVolume( + driverContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala new file mode 100644 index 000000000000..4553f9f6b1d4 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.when +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.Config._ + +class BasicInitContainerConfigurationStepSuite extends SparkFunSuite with BeforeAndAfter { + + private val SPARK_JARS = Seq( + "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") + private val SPARK_FILES = Seq( + "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") + private val JARS_DOWNLOAD_PATH = "/var/data/jars" + private val FILES_DOWNLOAD_PATH = "/var/data/files" + private val POD_LABEL = Map("bootstrap" -> "true") + private val INIT_CONTAINER_NAME = "init-container" + private val DRIVER_CONTAINER_NAME = "driver-container" + + @Mock + private var podAndInitContainerBootstrap : InitContainerBootstrap = _ + + before { + MockitoAnnotations.initMocks(this) + when(podAndInitContainerBootstrap.bootstrapInitContainer( + any[PodWithDetachedInitContainer])).thenAnswer(new Answer[PodWithDetachedInitContainer] { + override def answer(invocation: InvocationOnMock) : PodWithDetachedInitContainer = { + val pod = invocation.getArgumentAt(0, classOf[PodWithDetachedInitContainer]) + pod.copy( + pod = new PodBuilder(pod.pod) + .withNewMetadata() + .addToLabels("bootstrap", "true") + .endMetadata() + .withNewSpec().endSpec() + .build(), + initContainer = new ContainerBuilder() + .withName(INIT_CONTAINER_NAME) + .build(), + mainContainer = new ContainerBuilder() + .withName(DRIVER_CONTAINER_NAME) + .build() + )}}) + } + + test("additionalDriverSparkConf with mix of remote files and jars") { + val baseInitStep = new BasicInitContainerConfigurationStep( + SPARK_JARS, + SPARK_FILES, + JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_PATH, + podAndInitContainerBootstrap) + val expectedDriverSparkConf = Map( + JARS_DOWNLOAD_LOCATION.key -> JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_LOCATION.key -> FILES_DOWNLOAD_PATH, + INIT_CONTAINER_REMOTE_JARS.key -> "hdfs://localhost:9000/app/jars/jar1.jar", + INIT_CONTAINER_REMOTE_FILES.key -> "hdfs://localhost:9000/app/files/file1.txt") + val initContainerSpec = InitContainerSpec( + Map.empty[String, String], + Map.empty[String, String], + new Container(), + new Container(), + new Pod, + Seq.empty[HasMetadata]) + val returnContainerSpec = baseInitStep.configureInitContainer(initContainerSpec) + assert(expectedDriverSparkConf === returnContainerSpec.properties) + assert(returnContainerSpec.initContainer.getName === INIT_CONTAINER_NAME) + assert(returnContainerSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) + assert(returnContainerSpec.driverPod.getMetadata.getLabels.asScala === POD_LABEL) + } +} diff --git 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala new file mode 100644 index 000000000000..20f2e5bc15df --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class InitContainerConfigOrchestratorSuite extends SparkFunSuite { + + private val DOCKER_IMAGE = "init-container" + private val SPARK_JARS = Seq( + "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") + private val SPARK_FILES = Seq( + "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") + private val JARS_DOWNLOAD_PATH = "/var/data/jars" + private val FILES_DOWNLOAD_PATH = "/var/data/files" + private val DOCKER_IMAGE_PULL_POLICY: String = "IfNotPresent" + private val CUSTOM_LABEL_KEY = "customLabel" + private val CUSTOM_LABEL_VALUE = "customLabelValue" + private val INIT_CONTAINER_CONFIG_MAP_NAME = "spark-init-config-map" + private val INIT_CONTAINER_CONFIG_MAP_KEY = "spark-init-config-map-key" + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" + + test("including basic configuration step") { + val sparkConf = new SparkConf(true) + .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) + + val orchestrator = new InitContainerConfigOrchestrator( + SPARK_JARS.take(1), + SPARK_FILES, + JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_PATH, + DOCKER_IMAGE_PULL_POLICY, + INIT_CONTAINER_CONFIG_MAP_NAME, + INIT_CONTAINER_CONFIG_MAP_KEY, + sparkConf) + val initSteps = orchestrator.getAllConfigurationSteps + assert(initSteps.lengthCompare(1) == 0) + assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) + } + + test("including step to mount user-specified secrets") { + val sparkConf = new SparkConf(false) + .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) + .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) + + val orchestrator = new InitContainerConfigOrchestrator( + SPARK_JARS.take(1), + SPARK_FILES, + JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_PATH, + DOCKER_IMAGE_PULL_POLICY, + 
INIT_CONTAINER_CONFIG_MAP_NAME, + INIT_CONTAINER_CONFIG_MAP_KEY, + sparkConf) + val initSteps = orchestrator.getAllConfigurationSteps + assert(initSteps.length === 2) + assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) + assert(initSteps(1).isInstanceOf[InitContainerMountSecretsStep]) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala new file mode 100644 index 000000000000..eab4e1765945 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.MountSecretsBootstrap +import org.apache.spark.deploy.k8s.submit.SecretVolumeUtils + +class InitContainerMountSecretsStepSuite extends SparkFunSuite { + + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" + + test("mounts all given secrets") { + val baseInitContainerSpec = InitContainerSpec( + Map.empty, + Map.empty, + new ContainerBuilder().build(), + new ContainerBuilder().build(), + new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), + Seq.empty) + val secretNamesToMountPaths = Map( + SECRET_FOO -> SECRET_MOUNT_PATH, + SECRET_BAR -> SECRET_MOUNT_PATH) + + val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) + val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) + val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( + baseInitContainerSpec) + + val podWithSecretsMounted = configuredInitContainerSpec.driverPod + val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer + + Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => + assert(SecretVolumeUtils.podHasVolume(podWithSecretsMounted, volumeName))) + Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => + assert(SecretVolumeUtils.containerHasVolume( + initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala new file mode 100644 index 000000000000..6c557ec4a7c9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.rest.k8s + +import java.io.File +import java.util.UUID + +import com.google.common.base.Charsets +import com.google.common.io.Files +import org.mockito.Mockito +import org.scalatest.BeforeAndAfter +import org.scalatest.mockito.MockitoSugar._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.util.Utils + +class SparkPodInitContainerSuite extends SparkFunSuite with BeforeAndAfter { + + private val DOWNLOAD_JARS_SECRET_LOCATION = createTempFile("txt") + private val DOWNLOAD_FILES_SECRET_LOCATION = createTempFile("txt") + + private var downloadJarsDir: File = _ + private var downloadFilesDir: File = _ + private var downloadJarsSecretValue: String = _ + private var downloadFilesSecretValue: String = _ + private var fileFetcher: FileFetcher = _ + + override def beforeAll(): Unit = { + downloadJarsSecretValue = Files.toString( + new File(DOWNLOAD_JARS_SECRET_LOCATION), Charsets.UTF_8) + downloadFilesSecretValue = Files.toString( + new File(DOWNLOAD_FILES_SECRET_LOCATION), Charsets.UTF_8) + } + + before { + downloadJarsDir = Utils.createTempDir() + downloadFilesDir = Utils.createTempDir() + fileFetcher = mock[FileFetcher] + } + + after { + downloadJarsDir.delete() + downloadFilesDir.delete() + } + + test("Downloads from remote server should invoke the file fetcher") { + val sparkConf = getSparkConfForRemoteFileDownloads + val initContainerUnderTest = new SparkPodInitContainer(sparkConf, fileFetcher) + initContainerUnderTest.run() + Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/jar1.jar", downloadJarsDir) + Mockito.verify(fileFetcher).fetchFile("hdfs://localhost:9000/jar2.jar", downloadJarsDir) + Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/file.txt", downloadFilesDir) + } + + private def getSparkConfForRemoteFileDownloads: SparkConf = { + new SparkConf(true) + .set(INIT_CONTAINER_REMOTE_JARS, + "http://localhost:9000/jar1.jar,hdfs://localhost:9000/jar2.jar") + .set(INIT_CONTAINER_REMOTE_FILES, + "http://localhost:9000/file.txt") + .set(JARS_DOWNLOAD_LOCATION, downloadJarsDir.getAbsolutePath) + .set(FILES_DOWNLOAD_LOCATION, downloadFilesDir.getAbsolutePath) + } + + private def createTempFile(extension: String): String = { + val dir = Utils.createTempDir() + val file = new File(dir, s"${UUID.randomUUID().toString}.$extension") + 
Files.write(UUID.randomUUID().toString, file, Charsets.UTF_8) + file.getAbsolutePath + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 3a55d7cb37b1..7121a802c69c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -18,15 +18,19 @@ package org.apache.spark.scheduler.cluster.k8s import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{Pod, _} -import org.mockito.MockitoAnnotations +import io.fabric8.kubernetes.api.model._ +import org.mockito.{AdditionalAnswers, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + private val driverPodName: String = "driver-pod" private val driverPodUid: String = "driver-uid" private val executorPrefix: String = "base" @@ -54,7 +58,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef } test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactoryImpl(baseConf) + val factory = new ExecutorPodFactory(baseConf, None, None, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -85,7 +89,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - val factory = new ExecutorPodFactoryImpl(conf) + val factory = new ExecutorPodFactory(conf, None, None, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -97,7 +101,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - val factory = new ExecutorPodFactoryImpl(conf) + val factory = new ExecutorPodFactory(conf, None, None, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) @@ -108,6 +112,74 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } + test("executor secrets get mounted") { + val conf = baseConf.clone() + + val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) + val factory = new ExecutorPodFactory( + conf, + Some(secretsBootstrap), + None, + None) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() === 1) + 
assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0).getName + === "secret1-volume") + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0) + .getMountPath === "/var/secret1") + + // check volume mounted. + assert(executor.getSpec.getVolumes.size() === 1) + assert(executor.getSpec.getVolumes.get(0).getSecret.getSecretName === "secret1") + + checkOwnerReferences(executor, driverPodUid) + } + + test("init-container bootstrap step adds an init container") { + val conf = baseConf.clone() + val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) + when(initContainerBootstrap.bootstrapInitContainer( + any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + + val factory = new ExecutorPodFactory( + conf, + None, + Some(initContainerBootstrap), + None) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getInitContainers.size() === 1) + checkOwnerReferences(executor, driverPodUid) + } + + test("init-container with secrets mount bootstrap") { + val conf = baseConf.clone() + val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) + when(initContainerBootstrap.bootstrapInitContainer( + any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) + + val factory = new ExecutorPodFactory( + conf, + None, + Some(initContainerBootstrap), + Some(secretsBootstrap)) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getInitContainers.size() === 1) + assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0).getName + === "secret1-volume") + assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0) + .getMountPath === "/var/secret1") + + checkOwnerReferences(executor, driverPodUid) + } + // There is always exactly one controller reference, and it points to the driver pod. private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index 9b682f8673c6..45fbcd9cd0de 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -22,7 +22,7 @@ FROM spark-base # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-driver:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . +# docker build -t spark-driver:latest -f kubernetes/dockerfiles/driver/Dockerfile . COPY examples /opt/spark/examples @@ -31,4 +31,5 @@ CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt && \ if ! [ -z ${SPARK_MOUNTED_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ if ! [ -z ${SPARK_SUBMIT_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_SUBMIT_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ + if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." 
.; fi && \ ${JAVA_HOME}/bin/java "${SPARK_DRIVER_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS $SPARK_DRIVER_CLASS $SPARK_DRIVER_ARGS diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 168cd4cb6c57..0f806cf7e148 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -22,7 +22,7 @@ FROM spark-base # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-executor:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . +# docker build -t spark-executor:latest -f kubernetes/dockerfiles/executor/Dockerfile . COPY examples /opt/spark/examples @@ -31,4 +31,5 @@ CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt && \ if ! [ -z ${SPARK_MOUNTED_CLASSPATH}+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ if ! [ -z ${SPARK_EXECUTOR_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_EXECUTOR_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ + if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." .; fi && \ ${JAVA_HOME}/bin/java "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" org.apache.spark.executor.CoarseGrainedExecutorBackend --driver-url $SPARK_DRIVER_URL --executor-id $SPARK_EXECUTOR_ID --cores $SPARK_EXECUTOR_CORES --app-id $SPARK_APPLICATION_ID --hostname $SPARK_EXECUTOR_POD_IP diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile new file mode 100644 index 000000000000..055493188fcb --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FROM spark-base + +# If this docker file is being used in the context of building your images from a Spark distribution, the docker build +# command should be invoked from the top level directory of the Spark distribution. E.g.: +# docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . 
+ +ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.rest.k8s.SparkPodInitContainer" ] From ded6d27e4eb02e4530015a95794e6ed0586faaa7 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 28 Dec 2017 13:53:04 +0900 Subject: [PATCH 023/100] [SPARK-22648][K8S] Add documentation covering init containers and secrets ## What changes were proposed in this pull request? This PR updates the Kubernetes documentation corresponding to the following features/changes in #19954. * Ability to use remote dependencies through the init-container. * Ability to mount user-specified secrets into the driver and executor pods. vanzin jiangxb1987 foxish Author: Yinan Li Closes #20059 from liyinan926/doc-update. --- docs/running-on-kubernetes.md | 194 ++++++++++++++++++++++--------- sbin/build-push-docker-images.sh | 3 +- 2 files changed, 143 insertions(+), 54 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 0048bd90b48a..e491329136a3 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -69,17 +69,17 @@ building using the supplied script, or manually. To launch Spark Pi in cluster mode, -{% highlight bash %} +```bash $ bin/spark-submit \ --master k8s://https://: \ --deploy-mode cluster \ --name spark-pi \ --class org.apache.spark.examples.SparkPi \ --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.docker.image= \ - --conf spark.kubernetes.executor.docker.image= \ + --conf spark.kubernetes.driver.container.image= \ + --conf spark.kubernetes.executor.container.image= \ local:///path/to/examples.jar -{% endhighlight %} +``` The Spark master, specified either via passing the `--master` command line argument to `spark-submit` or by setting `spark.master` in the application's configuration, must be a URL with the format `k8s://`. Prefixing the @@ -120,6 +120,54 @@ by their appropriate remote URIs. Also, application dependencies can be pre-moun Those dependencies can be added to the classpath by referencing them with `local://` URIs and/or setting the `SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. +### Using Remote Dependencies +When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods +need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading +the dependencies so the driver and executor containers can use them locally. This requires users to specify the container +image for the init-container using the configuration property `spark.kubernetes.initContainer.image`. For example, users +simply add the following option to the `spark-submit` command to specify the init-container image: + +``` +--conf spark.kubernetes.initContainer.image= +``` + +The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and +`spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g., +the main application jar. 
The following shows an example of using remote dependencies with the `spark-submit` command: + +```bash +$ bin/spark-submit \ + --master k8s://https://: \ + --deploy-mode cluster \ + --name spark-pi \ + --class org.apache.spark.examples.SparkPi \ + --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar + --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2 + --conf spark.executor.instances=5 \ + --conf spark.kubernetes.driver.container.image= \ + --conf spark.kubernetes.executor.container.image= \ + --conf spark.kubernetes.initContainer.image= + https://path/to/examples.jar +``` + +## Secret Management +Kubernetes [Secrets](https://kubernetes.io/docs/concepts/configuration/secret/) can be used to provide credentials for a +Spark application to access secured services. To mount a user-specified secret into the driver container, users can use +the configuration property of the form `spark.kubernetes.driver.secrets.[SecretName]=`. Similarly, the +configuration property of the form `spark.kubernetes.executor.secrets.[SecretName]=` can be used to mount a +user-specified secret into the executor containers. Note that it is assumed that the secret to be mounted is in the same +namespace as that of the driver and executor pods. For example, to mount a secret named `spark-secret` onto the path +`/etc/secrets` in both the driver and executor containers, add the following options to the `spark-submit` command: + +``` +--conf spark.kubernetes.driver.secrets.spark-secret=/etc/secrets +--conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets +``` + +Note that if an init-container is used, any secret mounted into the driver container will also be mounted into the +init-container of the driver. Similarly, any secret mounted into an executor container will also be mounted into the +init-container of the executor. + ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -275,7 +323,7 @@ specific to Spark on Kubernetes. (none) Container image to use for the driver. - This is usually of the form `example.com/repo/spark-driver:v1.0.0`. + This is usually of the form example.com/repo/spark-driver:v1.0.0. This configuration is required and must be provided by the user. @@ -284,7 +332,7 @@ specific to Spark on Kubernetes. (none) Container image to use for the executors. - This is usually of the form `example.com/repo/spark-executor:v1.0.0`. + This is usually of the form example.com/repo/spark-executor:v1.0.0. This configuration is required and must be provided by the user. @@ -528,51 +576,91 @@ specific to Spark on Kubernetes. - spark.kubernetes.driver.limit.cores - (none) - - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. - - - - spark.kubernetes.executor.limit.cores - (none) - - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. - - - - spark.kubernetes.node.selector.[labelKey] - (none) - - Adds to the node selector of the driver pod and executor pods, with key labelKey and the value as the - configuration's value. 
For example, setting spark.kubernetes.node.selector.identifier to myIdentifier - will result in the driver pod and executors having a node selector with key identifier and value - myIdentifier. Multiple node selector keys can be added by setting multiple configurations with this prefix. - - - - spark.kubernetes.driverEnv.[EnvironmentVariableName] - (none) - - Add the environment variable specified by EnvironmentVariableName to - the Driver process. The user can specify multiple of these to set multiple environment variables. - - - - spark.kubernetes.mountDependencies.jarsDownloadDir - /var/spark-data/spark-jars - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.filesDownloadDir - /var/spark-data/spark-files - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - + spark.kubernetes.driver.limit.cores + (none) + + Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + + + + spark.kubernetes.executor.limit.cores + (none) + + Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. + + + + spark.kubernetes.node.selector.[labelKey] + (none) + + Adds to the node selector of the driver pod and executor pods, with key labelKey and the value as the + configuration's value. For example, setting spark.kubernetes.node.selector.identifier to myIdentifier + will result in the driver pod and executors having a node selector with key identifier and value + myIdentifier. Multiple node selector keys can be added by setting multiple configurations with this prefix. + + + + spark.kubernetes.driverEnv.[EnvironmentVariableName] + (none) + + Add the environment variable specified by EnvironmentVariableName to + the Driver process. The user can specify multiple of these to set multiple environment variables. + + + + spark.kubernetes.mountDependencies.jarsDownloadDir + /var/spark-data/spark-jars + + Location to download jars to in the driver and executors. + This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. + + + + spark.kubernetes.mountDependencies.filesDownloadDir + /var/spark-data/spark-files + + Location to download jars to in the driver and executors. + This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. + + + + spark.kubernetes.mountDependencies.timeout + 300s + + Timeout in seconds before aborting the attempt to download and unpack dependencies from remote locations into + the driver and executor pods. + + + + spark.kubernetes.mountDependencies.maxSimultaneousDownloads + 5 + + Maximum number of remote dependencies to download simultaneously in a driver or executor pod. + + + + spark.kubernetes.initContainer.image + (none) + + Container image for the init-container of the driver and executors for downloading dependencies. This is usually of the form example.com/repo/spark-init:v1.0.0. 
This configuration is optional and must be provided by the user if any non-container local dependency is used and must be downloaded remotely. + + + + spark.kubernetes.driver.secrets.[SecretName] + (none) + + Add the Kubernetes Secret named SecretName to the driver pod on the path specified in the value. For example, + spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, + the secret will also be added to the init-container in the driver pod. + + + + spark.kubernetes.executor.secrets.[SecretName] + (none) + + Add the Kubernetes Secret named SecretName to the executor pod on the path specified in the value. For example, + spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, + the secret will also be added to the init-container in the executor pod. + + \ No newline at end of file diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index 4546e98dc207..b3137598692d 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -20,7 +20,8 @@ # with Kubernetes support. declare -A path=( [spark-driver]=kubernetes/dockerfiles/driver/Dockerfile \ - [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile ) + [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile \ + [spark-init]=kubernetes/dockerfiles/init-container/Dockerfile ) function build { docker build -t spark-base -f kubernetes/dockerfiles/spark-base/Dockerfile . From 76e8a1d7e2619c1e6bd75c399314d2583a86b93b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 28 Dec 2017 20:17:26 +0900 Subject: [PATCH 024/100] [SPARK-22843][R] Adds localCheckpoint in R ## What changes were proposed in this pull request? This PR proposes to add `localCheckpoint(..)` in R API. ```r df <- localCheckpoint(createDataFrame(iris)) ``` ## How was this patch tested? Unit tests added in `R/pkg/tests/fulltests/test_sparkSQL.R` Author: hyukjinkwon Closes #20073 from HyukjinKwon/SPARK-22843. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 27 +++++++++++++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 22 ++++++++++++++++++++++ 4 files changed, 54 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index dce64e1e607c..4b699de558b4 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -133,6 +133,7 @@ exportMethods("arrange", "isStreaming", "join", "limit", + "localCheckpoint", "merge", "mutate", "na.omit", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index b8d732a48586..ace49daf9cd8 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3782,6 +3782,33 @@ setMethod("checkpoint", dataFrame(df) }) +#' localCheckpoint +#' +#' Returns a locally checkpointed version of this SparkDataFrame. Checkpointing can be used to +#' truncate the logical plan, which is especially useful in iterative algorithms where the plan +#' may grow exponentially. Local checkpoints are stored in the executors using the caching +#' subsystem and therefore they are not reliable. 
+#' +#' @param x A SparkDataFrame +#' @param eager whether to locally checkpoint this SparkDataFrame immediately +#' @return a new locally checkpointed SparkDataFrame +#' @family SparkDataFrame functions +#' @aliases localCheckpoint,SparkDataFrame-method +#' @rdname localCheckpoint +#' @name localCheckpoint +#' @export +#' @examples +#'\dontrun{ +#' df <- localCheckpoint(df) +#' } +#' @note localCheckpoint since 2.3.0 +setMethod("localCheckpoint", + signature(x = "SparkDataFrame"), + function(x, eager = TRUE) { + df <- callJMethod(x@sdf, "localCheckpoint", as.logical(eager)) + dataFrame(df) + }) + #' cube #' #' Create a multi-dimensional cube for the SparkDataFrame using the specified columns. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5ddaa669f920..d5d0bc9d8a97 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -611,6 +611,10 @@ setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) +#' @rdname localCheckpoint +#' @export +setGeneric("localCheckpoint", function(x, eager = TRUE) { standardGeneric("localCheckpoint") }) + #' @rdname merge #' @export setGeneric("merge") diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 6cc0188dae95..650e7c05f46e 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -957,6 +957,28 @@ test_that("setCheckpointDir(), checkpoint() on a DataFrame", { } }) +test_that("localCheckpoint() on a DataFrame", { + if (windows_with_hadoop()) { + # Checkpoint directory shouldn't matter in localCheckpoint. + checkpointDir <- file.path(tempdir(), "lcproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE, recursive = TRUE)) == 0) + setCheckpointDir(checkpointDir) + + textPath <- tempfile(pattern = "textPath", fileext = ".txt") + writeLines(mockLines, textPath) + # Read it lazily and then locally checkpoint eagerly. + df <- read.df(textPath, "text") + df <- localCheckpoint(df, eager = TRUE) + # Here, we remove the source path to check eagerness. + unlink(textPath) + expect_is(df, "SparkDataFrame") + expect_equal(colnames(df), c("value")) + expect_equal(count(df), 3) + + expect_true(length(list.files(path = checkpointDir, all.files = TRUE, recursive = TRUE)) == 0) + } +}) + test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- read.json(jsonPath) testSchema <- schema(df) From 1eebfbe192060af3c81cd086bc5d5a7e80d09e77 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 28 Dec 2017 20:18:47 +0900 Subject: [PATCH 025/100] [SPARK-21208][R] Adds setLocalProperty and getLocalProperty in R ## What changes were proposed in this pull request? This PR adds `setLocalProperty` and `getLocalProperty` in R. ```R > df <- createDataFrame(iris) > setLocalProperty("spark.job.description", "Hello world!") > count(df) > setLocalProperty("spark.job.description", "Hi !!") > count(df) ``` ```R > print(getLocalProperty("spark.job.description")) NULL > setLocalProperty("spark.job.description", "Hello world!") > print(getLocalProperty("spark.job.description")) [1] "Hello world!" > setLocalProperty("spark.job.description", "Hi !!") > print(getLocalProperty("spark.job.description")) [1] "Hi !!" ``` ## How was this patch tested? Manually tested and a test in `R/pkg/tests/fulltests/test_context.R`. Author: hyukjinkwon Closes #20075 from HyukjinKwon/SPARK-21208.
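The new R helpers are thin wrappers over the JVM-side `SparkContext.setLocalProperty` and `SparkContext.getLocalProperty`, which they reach through `callJMethod`, so they inherit the same thread-scoped semantics. As an illustration only (this sketch is not part of the patch, and the object and app names here are made up), the equivalent behaviour on the Scala side looks like:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch of the SparkContext semantics the new R wrappers delegate to.
object LocalPropertySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("local-property-sketch")
      .getOrCreate()
    val sc = spark.sparkContext

    // Local properties are scoped to the calling thread (and inherited by child
    // threads), which is why they suit per-job metadata such as a job description
    // or a fair scheduler pool name.
    sc.setLocalProperty("spark.job.description", "counting a small range")
    assert(sc.getLocalProperty("spark.job.description") == "counting a small range")
    spark.range(1000).count()

    // Passing null clears the property, which is what the R wrappers do for NULL/NA.
    sc.setLocalProperty("spark.job.description", null)
    assert(sc.getLocalProperty("spark.job.description") == null)

    spark.stop()
  }
}
```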
--- R/pkg/NAMESPACE | 4 ++- R/pkg/R/sparkR.R | 45 ++++++++++++++++++++++++++++ R/pkg/tests/fulltests/test_context.R | 33 +++++++++++++++++++- 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 4b699de558b4..ce3eec069efe 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -76,7 +76,9 @@ exportMethods("glm", export("setJobGroup", "clearJobGroup", "cancelJobGroup", - "setJobDescription") + "setJobDescription", + "setLocalProperty", + "getLocalProperty") # Export Utility methods export("setLogLevel") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index fb5f1d21fc72..965471f3b07a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -560,10 +560,55 @@ cancelJobGroup <- function(sc, groupId) { #'} #' @note setJobDescription since 2.3.0 setJobDescription <- function(value) { + if (!is.null(value)) { + value <- as.character(value) + } sc <- getSparkContext() invisible(callJMethod(sc, "setJobDescription", value)) } +#' Set a local property that affects jobs submitted from this thread, such as the +#' Spark fair scheduler pool. +#' +#' @param key The key for a local property. +#' @param value The value for a local property. +#' @rdname setLocalProperty +#' @name setLocalProperty +#' @examples +#'\dontrun{ +#' setLocalProperty("spark.scheduler.pool", "poolA") +#'} +#' @note setLocalProperty since 2.3.0 +setLocalProperty <- function(key, value) { + if (is.null(key) || is.na(key)) { + stop("key should not be NULL or NA.") + } + if (!is.null(value)) { + value <- as.character(value) + } + sc <- getSparkContext() + invisible(callJMethod(sc, "setLocalProperty", as.character(key), value)) +} + +#' Get a local property set in this thread, or \code{NULL} if it is missing. See +#' \code{setLocalProperty}. +#' +#' @param key The key for a local property. 
+#' @rdname getLocalProperty +#' @name getLocalProperty +#' @examples +#'\dontrun{ +#' getLocalProperty("spark.scheduler.pool") +#'} +#' @note getLocalProperty since 2.3.0 +getLocalProperty <- function(key) { + if (is.null(key) || is.na(key)) { + stop("key should not be NULL or NA.") + } + sc <- getSparkContext() + callJMethod(sc, "getLocalProperty", as.character(key)) +} + sparkConfToSubmitOps <- new.env() sparkConfToSubmitOps[["spark.driver.memory"]] <- "--driver-memory" sparkConfToSubmitOps[["spark.driver.extraClassPath"]] <- "--driver-class-path" diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index 77635c5a256b..f0d0a5114f89 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -100,7 -100,6 @@ test_that("job group functions can be called", { setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() - setJobDescription("job description") suppressWarnings(setJobGroup(sc, "groupId", "job description", TRUE)) suppressWarnings(cancelJobGroup(sc, "groupId")) @@ -108,6 +107,38 @@ test_that("job group functions can be called", { sparkR.session.stop() }) +test_that("job description and local properties can be set and got", { + sc <- sparkR.sparkContext(master = sparkRTestMaster) + setJobDescription("job description") + expect_equal(getLocalProperty("spark.job.description"), "job description") + setJobDescription(1234) + expect_equal(getLocalProperty("spark.job.description"), "1234") + setJobDescription(NULL) + expect_equal(getLocalProperty("spark.job.description"), NULL) + setJobDescription(NA) + expect_equal(getLocalProperty("spark.job.description"), NULL) + + setLocalProperty("spark.scheduler.pool", "poolA") + expect_equal(getLocalProperty("spark.scheduler.pool"), "poolA") + setLocalProperty("spark.scheduler.pool", NULL) + expect_equal(getLocalProperty("spark.scheduler.pool"), NULL) + setLocalProperty("spark.scheduler.pool", NA) + expect_equal(getLocalProperty("spark.scheduler.pool"), NULL) + + setLocalProperty(4321, 1234) + expect_equal(getLocalProperty(4321), "1234") + setLocalProperty(4321, NULL) + expect_equal(getLocalProperty(4321), NULL) + setLocalProperty(4321, NA) + expect_equal(getLocalProperty(4321), NULL) + + expect_error(setLocalProperty(NULL, "should fail"), "key should not be NULL or NA") + expect_error(getLocalProperty(NULL), "key should not be NULL or NA") + expect_error(setLocalProperty(NA, "should fail"), "key should not be NULL or NA") + expect_error(getLocalProperty(NA), "key should not be NULL or NA") + sparkR.session.stop() +}) + test_that("utility function can be called", { sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") From 755f2f5189a08597fddc90b7f9df4ad0ec6bd2ef Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 28 Dec 2017 21:33:03 +0800 Subject: [PATCH 026/100] [SPARK-20392][SQL][FOLLOWUP] should not add extra AnalysisBarrier ## What changes were proposed in this pull request? I found this problem while auditing the analyzer code. It's dangerous to introduce extra `AnalysisBarrier` during analysis, as the plan inside it will bypass all analysis afterward, which may not be expected. We should only preserve `AnalysisBarrier` but not introduce new ones. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20094 from cloud-fan/barrier.
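To make the hazard concrete, consider a self-contained toy sketch; the `Plan` and `Barrier` classes below are illustrative stand-ins rather than Spark's `LogicalPlan` and `AnalysisBarrier`. Because a barrier presents itself as a leaf, a bottom-up transform never visits, and therefore never resolves, the subtree it wraps:

```scala
// Toy model of a plan tree and a resolution rule (not Spark source). A Barrier
// reports no children, so transformUp never descends into the plan it hides.
sealed trait Plan {
  def children: Seq[Plan]
  protected def mapChildren(f: Plan => Plan): Plan
  def transformUp(rule: PartialFunction[Plan, Plan]): Plan = {
    val afterChildren = mapChildren(_.transformUp(rule))
    rule.applyOrElse(afterChildren, identity[Plan])
  }
}
case class Unresolved(name: String) extends Plan {
  val children: Seq[Plan] = Nil
  protected def mapChildren(f: Plan => Plan): Plan = this
}
case class Resolved(name: String) extends Plan {
  val children: Seq[Plan] = Nil
  protected def mapChildren(f: Plan => Plan): Plan = this
}
case class Project(child: Plan) extends Plan {
  val children: Seq[Plan] = Seq(child)
  protected def mapChildren(f: Plan => Plan): Plan = Project(f(child))
}
// Acts like a leaf: the wrapped plan is invisible to tree transformations.
case class Barrier(hidden: Plan) extends Plan {
  val children: Seq[Plan] = Nil
  protected def mapChildren(f: Plan => Plan): Plan = this
}

object BarrierSketch extends App {
  val resolve: PartialFunction[Plan, Plan] = { case Unresolved(n) => Resolved(n) }
  // Without a barrier the attribute is resolved: Project(Resolved(a))
  println(Project(Unresolved("a")).transformUp(resolve))
  // With an extra barrier it silently stays unresolved: Project(Barrier(Unresolved(a)))
  println(Project(Barrier(Unresolved("a"))).transformUp(resolve))
}
```

This is why the changes below let `dedupRight` and the missing-attribute resolution recurse into, or rebuild around, barriers that already exist instead of wrapping their results in new ones.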
--- .../sql/catalyst/analysis/Analyzer.scala | 191 ++++++++---------- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 2 files changed, 84 insertions(+), 109 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 10b237fb22b9..7f2128ed36a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -665,14 +664,18 @@ class Analyzer( * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. */ - private def dedupRight (left: LogicalPlan, originalRight: LogicalPlan): LogicalPlan = { - // Remove analysis barrier if any. - val right = EliminateBarriers(originalRight) + private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") right.collect { + // For `AnalysisBarrier`, recursively de-duplicate its child. + case oldVersion: AnalysisBarrier + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = dedupRight(left, oldVersion.child) + (oldVersion, AnalysisBarrier(newVersion)) + // Handle base relations that might appear more than once. case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -710,10 +713,10 @@ class Analyzer( * that this rule cannot handle. When that is the case, there must be another rule * that resolves these conflicts. Otherwise, the analysis will fail. */ - originalRight + right case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { + right transformUp { case r if r == oldRelation => newRelation } transformUp { case other => other transformExpressions { @@ -723,7 +726,6 @@ class Analyzer( s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } - AnalysisBarrier(newRight) } } @@ -958,7 +960,8 @@ class Analyzer( protected[sql] def resolveExpression( expr: Expression, plan: LogicalPlan, - throws: Boolean = false) = { + throws: Boolean = false): Expression = { + if (expr.resolved) return expr // Resolve expression in one round. // If throws == false or the desired attribute doesn't exist // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. @@ -1079,100 +1082,74 @@ class Analyzer( case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, originalChild) if !s.resolved && originalChild.resolved => - val child = EliminateBarriers(originalChild) - try { - val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) - val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) - val missingAttrs = requiredAttrs -- child.outputSet - if (missingAttrs.nonEmpty) { - // Add missing attributes and then project them away after the sort. 
- Project(child.output, - Sort(newOrder, s.global, addMissingAttr(child, missingAttrs))) - } else if (newOrder != order) { - s.copy(order = newOrder) - } else { - s - } - } catch { - // Attempting to resolve it might fail. When this happens, return the original plan. - // Users will see an AnalysisException for resolution failure of missing attributes - // in Sort - case ae: AnalysisException => s + case s @ Sort(order, _, child) if !s.resolved && child.resolved => + val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) + val ordering = newOrder.map(_.asInstanceOf[SortOrder]) + if (child.output == newChild.output) { + s.copy(order = ordering) + } else { + // Add missing attributes and then project them away. + val newSort = s.copy(order = ordering, child = newChild) + Project(child.output, newSort) } - case f @ Filter(cond, originalChild) if !f.resolved && originalChild.resolved => - val child = EliminateBarriers(originalChild) - try { - val newCond = resolveExpressionRecursively(cond, child) - val requiredAttrs = newCond.references.filter(_.resolved) - val missingAttrs = requiredAttrs -- child.outputSet - if (missingAttrs.nonEmpty) { - // Add missing attributes and then project them away. - Project(child.output, - Filter(newCond, addMissingAttr(child, missingAttrs))) - } else if (newCond != cond) { - f.copy(condition = newCond) - } else { - f - } - } catch { - // Attempting to resolve it might fail. When this happens, return the original plan. - // Users will see an AnalysisException for resolution failure of missing attributes - case ae: AnalysisException => f + case f @ Filter(cond, child) if !f.resolved && child.resolved => + val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) + if (child.output == newChild.output) { + f.copy(condition = newCond.head) + } else { + // Add missing attributes and then project them away. + val newFilter = Filter(newCond.head, newChild) + Project(child.output, newFilter) } } - /** - * Add the missing attributes into projectList of Project/Window or aggregateExpressions of - * Aggregate. - */ - private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { - if (missingAttrs.isEmpty) { - return AnalysisBarrier(plan) - } - plan match { - case p: Project => - val missing = missingAttrs -- p.child.outputSet - Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing)) - case a: Aggregate => - // all the missing attributes should be grouping expressions - // TODO: push down AggregateExpression - missingAttrs.foreach { attr => - if (!a.groupingExpressions.exists(_.semanticEquals(attr))) { - throw new AnalysisException(s"Can't add $attr to ${a.simpleString}") - } - } - val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs - a.copy(aggregateExpressions = newAggregateExpressions) - case g: Generate => - // If join is false, we will convert it to true for getting from the child the missing - // attributes that its child might have or could have. 
- val missing = missingAttrs -- g.child.outputSet - g.copy(join = true, child = addMissingAttr(g.child, missing)) - case d: Distinct => - throw new AnalysisException(s"Can't add $missingAttrs to $d") - case u: UnaryNode => - u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) - case other => - throw new AnalysisException(s"Can't add $missingAttrs to $other") - } - } - - /** - * Resolve the expression on a specified logical plan and it's child (recursively), until - * the expression is resolved or meet a non-unary node or Subquery. - */ - @tailrec - private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = { - val resolved = resolveExpression(expr, plan) - if (resolved.resolved) { - resolved + private def resolveExprsAndAddMissingAttrs( + exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { + if (exprs.forall(_.resolved)) { + // All given expressions are resolved, no need to continue anymore. + (exprs, plan) } else { plan match { - case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] => - resolveExpressionRecursively(resolved, u.child) - case other => resolved + // For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via + // its child. + case barrier: AnalysisBarrier => + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child) + (newExprs, AnalysisBarrier(newChild)) + + case p: Project => + val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) + val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + (newExprs, Project(p.projectList ++ missingAttrs, newChild)) + + case a @ Aggregate(groupExprs, aggExprs, child) => + val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) + val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { + // All the missing attributes are grouping expressions, valid case. + (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) + } else { + // Need to add non-grouping attributes, invalid case. + (exprs, a) + } + + case g: Generate => + val maybeResolvedExprs = exprs.map(resolveExpression(_, g)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) + (newExprs, g.copy(join = true, child = newChild)) + + // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes + // via its children. + case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => + val maybeResolvedExprs = exprs.map(resolveExpression(_, u)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) + (newExprs, u.withNewChildren(Seq(newChild))) + + // For other operators, we can't recursively resolve and add attributes via its children. 
+ case other => + (exprs.map(resolveExpression(_, other)), other) } } } @@ -1404,18 +1381,16 @@ class Analyzer( */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => - apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) - case filter @ Filter(havingCondition, - aggregate @ Aggregate(grouping, originalAggExprs, child)) - if aggregate.resolved => + case Filter(cond, AnalysisBarrier(agg: Aggregate)) => + apply(Filter(cond, agg)).mapChildren(AnalysisBarrier) + case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause try { val aggregatedCondition = Aggregate( grouping, - Alias(havingCondition, "havingCondition")() :: Nil, + Alias(cond, "havingCondition")() :: Nil, child) val resolvedOperator = execute(aggregatedCondition) def resolvedAggregateFilter = @@ -1436,7 +1411,7 @@ class Analyzer( // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]]. case e: Expression if grouping.exists(_.semanticEquals(e)) && !ResolveGroupingAnalytics.hasGroupingFunction(e) && - !aggregate.output.exists(_.semanticEquals(e)) => + !agg.output.exists(_.semanticEquals(e)) => e match { case ne: NamedExpression => aggregateExpressions += ne @@ -1450,22 +1425,22 @@ class Analyzer( // Push the aggregate expressions into the aggregate (if any). if (aggregateExpressions.nonEmpty) { - Project(aggregate.output, + Project(agg.output, Filter(transformedAggregateFilter, - aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions))) + agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions))) } else { - filter + f } } else { - filter + f } } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. - case ae: AnalysisException => filter + case ae: AnalysisException => f } - case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => + case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c11e37a51664..07ae3ae94584 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1562,7 +1562,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("multi-insert with lateral view") { - withTempView("t1") { + withTempView("source") { spark.range(10) .select(array($"id", $"id" + 1).as("arr"), $"id") .createOrReplaceTempView("source") From 28778174208664327b75915e83ae5e611360eef3 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Thu, 28 Dec 2017 21:49:37 +0800 Subject: [PATCH 027/100] [SPARK-22917][SQL] Should not try to generate histogram for empty/null columns ## What changes were proposed in this pull request? For empty/null column, the result of `ApproximatePercentile` is null. 
Then in `ApproxCountDistinctForIntervals`, a `MatchError` (for `endpoints`) will be thrown if we try to generate histogram for that column. Besides, there is no need to generate histogram for such column. In this patch, we exclude such column when generating histogram. ## How was this patch tested? Enhanced test cases for empty/null columns. Author: Zhenhua Wang Closes #20102 from wzhfy/no_record_hgm_bug. --- .../command/AnalyzeColumnCommand.scala | 7 ++++++- .../spark/sql/StatisticsCollectionSuite.scala | 21 +++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index e3bb4d357b39..1122522ccb4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -143,7 +143,12 @@ case class AnalyzeColumnCommand( val percentilesRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExprs, relation)) .executedPlan.executeTake(1).head attrsToGenHistogram.zipWithIndex.foreach { case (attr, i) => - attributePercentiles += attr -> percentilesRow.getArray(i) + val percentiles = percentilesRow.getArray(i) + // When there is no non-null value, `percentiles` is null. In such case, there is no + // need to generate histogram. + if (percentiles != null) { + attributePercentiles += attr -> percentiles + } } } AttributeMap(attributePercentiles.toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index fba5d2652d3f..b11e79853205 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -85,13 +85,24 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("analyze empty table") { val table = "emptyTable" withTable(table) { - sql(s"CREATE TABLE $table (key STRING, value STRING) USING PARQUET") + val df = Seq.empty[Int].toDF("key") + df.write.format("json").saveAsTable(table) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS noscan") val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats1.get.sizeInBytes == 0) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) assert(fetchedStats2.get.sizeInBytes == 0) + + val expectedColStat = + "key" -> ColumnStat(0, None, None, 0, IntegerType.defaultSize, IntegerType.defaultSize) + + // There won't be histogram for empty column. + Seq("true", "false").foreach { histogramEnabled => + withSQLConf(SQLConf.HISTOGRAM_ENABLED.key -> histogramEnabled) { + checkColStats(df, mutable.LinkedHashMap(expectedColStat)) + } + } } } @@ -178,7 +189,13 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val expectedColStats = dataTypes.map { case (tpe, idx) => (s"col$idx", ColumnStat(0, None, None, 1, tpe.defaultSize.toLong, tpe.defaultSize.toLong)) } - checkColStats(df, mutable.LinkedHashMap(expectedColStats: _*)) + + // There won't be histograms for null columns. 
+ Seq("true", "false").foreach { histogramEnabled => + withSQLConf(SQLConf.HISTOGRAM_ENABLED.key -> histogramEnabled) { + checkColStats(df, mutable.LinkedHashMap(expectedColStats: _*)) + } + } } test("number format in statistics") { From 5536f3181c1e77c70f01d6417407d218ea48b961 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 28 Dec 2017 09:43:50 -0600 Subject: [PATCH 028/100] [MINOR][BUILD] Fix Java linter errors ## What changes were proposed in this pull request? This PR cleans up a few Java linter errors for Apache Spark 2.3 release. ## How was this patch tested? ```bash $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks passed. ``` We can see the result from [Travis CI](https://travis-ci.org/dongjoon-hyun/spark/builds/322470787), too. Author: Dongjoon Hyun Closes #20101 from dongjoon-hyun/fix-java-lint. --- .../org/apache/spark/memory/MemoryConsumer.java | 3 ++- .../streaming/kinesis/KinesisInitialPositions.java | 14 ++++++++------ .../parquet/VectorizedColumnReader.java | 3 +-- .../parquet/VectorizedParquetRecordReader.java | 4 ++-- .../sql/execution/vectorized/ColumnarRow.java | 3 ++- .../spark/sql/sources/v2/SessionConfigSupport.java | 3 --- .../v2/streaming/ContinuousReadSupport.java | 5 ++++- .../v2/streaming/ContinuousWriteSupport.java | 6 +++--- .../sql/sources/v2/streaming/reader/Offset.java | 3 ++- .../v2/streaming/reader/PartitionOffset.java | 1 - .../hive/service/cli/operation/SQLOperation.java | 1 - 11 files changed, 24 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index a7bd4b3799a2..115e1fbb79a2 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -154,6 +154,7 @@ private void throwOom(final MemoryBlock page, final long required) { taskMemoryManager.freePage(page, this); } taskMemoryManager.showMemoryUsage(); - throw new SparkOutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + throw new SparkOutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + + got); } } diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java b/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java index 206e1e469903..b5f5ab0e9054 100644 --- a/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java +++ b/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java @@ -67,9 +67,10 @@ public Date getTimestamp() { /** - * Returns instance of [[KinesisInitialPosition]] based on the passed [[InitialPositionInStream]]. - * This method is used in KinesisUtils for translating the InitialPositionInStream - * to InitialPosition. This function would be removed when we deprecate the KinesisUtils. + * Returns instance of [[KinesisInitialPosition]] based on the passed + * [[InitialPositionInStream]]. This method is used in KinesisUtils for translating the + * InitialPositionInStream to InitialPosition. This function would be removed when we deprecate + * the KinesisUtils. * * @return [[InitialPosition]] */ @@ -83,9 +84,10 @@ public static KinesisInitialPosition fromKinesisInitialPosition( // InitialPositionInStream.AT_TIMESTAMP is not supported. // Use InitialPosition.atTimestamp(timestamp) instead. 
throw new UnsupportedOperationException( - "Only InitialPositionInStream.LATEST and InitialPositionInStream.TRIM_HORIZON " + - "supported in initialPositionInStream(). Please use the initialPosition() from " + - "builder API in KinesisInputDStream for using InitialPositionInStream.AT_TIMESTAMP"); + "Only InitialPositionInStream.LATEST and InitialPositionInStream." + + "TRIM_HORIZON supported in initialPositionInStream(). Please use " + + "the initialPosition() from builder API in KinesisInputDStream for " + + "using InitialPositionInStream.AT_TIMESTAMP"); } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 3ba180860c32..c120863152a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -31,7 +31,6 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.catalyst.util.DateTimeUtils; -import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -96,7 +95,7 @@ public class VectorizedColumnReader { private final OriginalType originalType; // The timezone conversion to apply to int96 timestamps. Null if no conversion. private final TimeZone convertTz; - private final static TimeZone UTC = DateTimeUtils.TimeZoneUTC(); + private static final TimeZone UTC = DateTimeUtils.TimeZoneUTC(); public VectorizedColumnReader( ColumnDescriptor descriptor, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 14f2a58d638f..6c157e85d411 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -79,8 +79,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa private boolean[] missingColumns; /** - * The timezone that timestamp INT96 values should be converted to. Null if no conversion. Here to workaround - * incompatibilities between different engines when writing timestamp values. + * The timezone that timestamp INT96 values should be converted to. Null if no conversion. Here to + * workaround incompatibilities between different engines when writing timestamp values. */ private TimeZone convertTz = null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java index 95c0d09873d6..8bb33ed5b78c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java @@ -28,7 +28,8 @@ * to be reused, callers should copy the data out if it needs to be stored. */ public final class ColumnarRow extends InternalRow { - // The data for this row. E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. + // The data for this row. + // E.g. 
the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. private final ColumnVector data; private final int rowId; private final int numFields; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index 0b5b6ac675f2..3cb020d2e083 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -19,9 +19,6 @@ import org.apache.spark.annotation.InterfaceStability; -import java.util.List; -import java.util.Map; - /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * propagate session configs with the specified key-prefix to all data source operations in this diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index 8837bae6156b..3136cee1f655 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -39,5 +39,8 @@ public interface ContinuousReadSupport extends DataSourceV2 { * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - ContinuousReader createContinuousReader(Optional schema, String checkpointLocation, DataSourceV2Options options); + ContinuousReader createContinuousReader( + Optional schema, + String checkpointLocation, + DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java index ec15e436672b..dee493cadb71 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java @@ -39,9 +39,9 @@ public interface ContinuousWriteSupport extends BaseStreamingSink { * Creates an optional {@link ContinuousWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done. * - * @param queryId A unique string for the writing query. It's possible that there are many writing - * queries running at the same time, and the returned {@link DataSourceV2Writer} - * can use this id to distinguish itself from others. + * @param queryId A unique string for the writing query. It's possible that there are many + * writing queries running at the same time, and the returned + * {@link DataSourceV2Writer} can use this id to distinguish itself from others. * @param schema the schema of the data to be written. * @param mode the output mode which determines what successive epoch output means to this * sink, please refer to {@link OutputMode} for more details. 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java index 517fdab16684..60b87f2ac075 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java @@ -42,7 +42,8 @@ public abstract class Offset extends org.apache.spark.sql.execution.streaming.Of @Override public boolean equals(Object obj) { if (obj instanceof org.apache.spark.sql.execution.streaming.Offset) { - return this.json().equals(((org.apache.spark.sql.execution.streaming.Offset) obj).json()); + return this.json() + .equals(((org.apache.spark.sql.execution.streaming.Offset) obj).json()); } else { return false; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java index 729a6123034f..eca0085c8a8c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java @@ -26,5 +26,4 @@ * These offsets must be serializable. */ public interface PartitionOffset extends Serializable { - } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java index fd9108eb53ca..70c27948de61 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java @@ -42,7 +42,6 @@ import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.VariableSubstitution; import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse; -import org.apache.hadoop.hive.ql.session.OperationLog; import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.SerDe; From 8f6d5734d504be89249fb6744ef4067be2e889fc Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Thu, 28 Dec 2017 11:44:06 -0600 Subject: [PATCH 029/100] [SPARK-22875][BUILD] Assembly build fails for a high user id ## What changes were proposed in this pull request? Add tarLongFileMode=posix configuration for the assembly plugin ## How was this patch tested? Reran build successfully ``` ./build/mvn package -Pbigtop-dist -DskipTests -rf :spark-assembly_2.11 [INFO] Spark Project Assembly ............................. SUCCESS [ 23.082 s] ``` Author: Gera Shegalov Closes #20055 from gerashegalov/gera/tarLongFileMode. --- pom.xml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pom.xml b/pom.xml index 92f897095f08..34793b0c4170 100644 --- a/pom.xml +++ b/pom.xml @@ -2295,6 +2295,9 @@ org.apache.maven.plugins maven-assembly-plugin 3.1.0 + + posix + org.apache.maven.plugins From 9c21ece35edc175edde949175a4ce701679824c8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 28 Dec 2017 15:41:16 -0600 Subject: [PATCH 030/100] [SPARK-22836][UI] Show driver logs in UI when available. Port code from the old executors listener to the new one, so that the driver logs present in the application start event are kept. Author: Marcelo Vanzin Closes #20038 from vanzin/SPARK-22836. 
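The field consumed by this change is the optional `driverLogs` map on `SparkListenerApplicationStart`: a cluster manager backend that knows the driver's log locations supplies it when the application starts, and the new listener code copies it onto the driver's executor summary so the UI can show it. A hedged sketch (the `listenerBus` handle and the URLs are illustrative, not part of the patch):

```scala
import org.apache.spark.scheduler.SparkListenerApplicationStart

// Illustrative log locations; a real backend (e.g. YARN) would point at container logs.
val driverLogs = Map(
  "stdout" -> "http://host:8042/logs/driver/stdout",
  "stderr" -> "http://host:8042/logs/driver/stderr")
listenerBus.post(SparkListenerApplicationStart(
  "my-app", Some("app-id"), System.currentTimeMillis(), "user",
  Some("attempt-1"), Some(driverLogs)))
```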
--- .../spark/status/AppStatusListener.scala | 11 +++++++++++ .../spark/status/AppStatusListenerSuite.scala | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 525329713732..487a782e865e 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -119,6 +119,17 @@ private[spark] class AppStatusListener( kvstore.write(new ApplicationInfoWrapper(appInfo)) kvstore.write(appSummary) + + // Update the driver block manager with logs from this event. The SparkContext initialization + // code registers the driver before this event is sent. + event.driverLogs.foreach { logs => + val driver = liveExecutors.get(SparkContext.DRIVER_IDENTIFIER) + .orElse(liveExecutors.get(SparkContext.LEGACY_DRIVER_IDENTIFIER)) + driver.foreach { d => + d.executorLogs = logs.toMap + update(d, System.nanoTime()) + } + } } override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index c0b3a79fe981..997c7de8dd02 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -942,6 +942,24 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("driver logs") { + val listener = new AppStatusListener(store, conf, true) + + val driver = BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "localhost", 42) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(time, driver, 42L)) + listener.onApplicationStart(SparkListenerApplicationStart( + "name", + Some("id"), + time, + "user", + Some("attempt"), + Some(Map("stdout" -> "file.txt")))) + + check[ExecutorSummaryWrapper](SparkContext.DRIVER_IDENTIFIER) { d => + assert(d.info.executorLogs("stdout") === "file.txt") + } + } + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { From 613b71a1230723ed556239b8b387722043c67252 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 29 Dec 2017 06:58:38 +0800 Subject: [PATCH 031/100] [SPARK-22890][TEST] Basic tests for DateTimeOperations ## What changes were proposed in this pull request? Test Coverage for `DateTimeOperations`, this is a Sub-tasks for [SPARK-22722](https://issues.apache.org/jira/browse/SPARK-22722). ## How was this patch tested? N/A Author: Yuming Wang Closes #20061 from wangyum/SPARK-22890. 
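The rule these golden files pin down is that only string, date and timestamp operands are implicitly coerced for interval arithmetic, while numeric, binary and boolean operands fail analysis. A two-line sketch of the same behavior from a Scala session (results taken from the expected output below; a `spark` session is assumed):

```scala
spark.sql("select cast('2017-12-11' as date) + interval 2 day").show()  // 2017-12-13
spark.sql("select cast(1 as int) + interval 2 day")                     // AnalysisException: differing types
```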
--- .../native/dateTimeOperations.sql | 60 +++ .../native/dateTimeOperations.sql.out | 349 ++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 13 +- 3 files changed, 410 insertions(+), 12 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql new file mode 100644 index 000000000000..1e9822186796 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql @@ -0,0 +1,60 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +select cast(1 as tinyint) + interval 2 day; +select cast(1 as smallint) + interval 2 day; +select cast(1 as int) + interval 2 day; +select cast(1 as bigint) + interval 2 day; +select cast(1 as float) + interval 2 day; +select cast(1 as double) + interval 2 day; +select cast(1 as decimal(10, 0)) + interval 2 day; +select cast('2017-12-11' as string) + interval 2 day; +select cast('2017-12-11 09:30:00' as string) + interval 2 day; +select cast('1' as binary) + interval 2 day; +select cast(1 as boolean) + interval 2 day; +select cast('2017-12-11 09:30:00.0' as timestamp) + interval 2 day; +select cast('2017-12-11 09:30:00' as date) + interval 2 day; + +select interval 2 day + cast(1 as tinyint); +select interval 2 day + cast(1 as smallint); +select interval 2 day + cast(1 as int); +select interval 2 day + cast(1 as bigint); +select interval 2 day + cast(1 as float); +select interval 2 day + cast(1 as double); +select interval 2 day + cast(1 as decimal(10, 0)); +select interval 2 day + cast('2017-12-11' as string); +select interval 2 day + cast('2017-12-11 09:30:00' as string); +select interval 2 day + cast('1' as binary); +select interval 2 day + cast(1 as boolean); +select interval 2 day + cast('2017-12-11 09:30:00.0' as timestamp); +select interval 2 day + cast('2017-12-11 09:30:00' as date); + +select cast(1 as tinyint) - interval 2 day; +select cast(1 as smallint) - interval 2 day; +select cast(1 as int) - interval 2 day; +select cast(1 as bigint) - interval 2 day; +select cast(1 as float) - interval 2 day; +select cast(1 as double) - interval 2 day; +select cast(1 as decimal(10, 0)) - interval 2 day; +select cast('2017-12-11' as string) - interval 2 day; +select cast('2017-12-11 09:30:00' as string) - interval 2 day; +select cast('1' as binary) - interval 2 day; +select cast(1 as boolean) - interval 2 day; +select cast('2017-12-11 09:30:00.0' as timestamp) - 
interval 2 day; +select cast('2017-12-11 09:30:00' as date) - interval 2 day; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out new file mode 100644 index 000000000000..12c1d1617679 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out @@ -0,0 +1,349 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 40 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select cast(1 as tinyint) + interval 2 day +-- !query 1 schema +struct<> +-- !query 1 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) + interval 2 days)' (tinyint and calendarinterval).; line 1 pos 7 + + +-- !query 2 +select cast(1 as smallint) + interval 2 day +-- !query 2 schema +struct<> +-- !query 2 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) + interval 2 days)' (smallint and calendarinterval).; line 1 pos 7 + + +-- !query 3 +select cast(1 as int) + interval 2 day +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS INT) + interval 2 days)' (int and calendarinterval).; line 1 pos 7 + + +-- !query 4 +select cast(1 as bigint) + interval 2 day +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) + interval 2 days)' (bigint and calendarinterval).; line 1 pos 7 + + +-- !query 5 +select cast(1 as float) + interval 2 day +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) + interval 2 days)' (float and calendarinterval).; line 1 pos 7 + + +-- !query 6 +select cast(1 as double) + interval 2 day +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) + interval 2 days)' (double and calendarinterval).; line 1 pos 7 + + +-- !query 7 +select cast(1 as decimal(10, 0)) + interval 2 day +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + interval 2 days)' (decimal(10,0) and calendarinterval).; line 1 pos 7 + + +-- !query 8 +select cast('2017-12-11' as string) + interval 2 day +-- !query 8 schema +struct +-- !query 8 output +2017-12-13 00:00:00 + + +-- !query 9 +select cast('2017-12-11 09:30:00' as string) + interval 2 day +-- !query 9 schema +struct +-- !query 9 output +2017-12-13 09:30:00 + + +-- !query 10 +select cast('1' as binary) + interval 2 day +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) + interval 2 days)' due to data type 
mismatch: differing types in '(CAST('1' AS BINARY) + interval 2 days)' (binary and calendarinterval).; line 1 pos 7 + + +-- !query 11 +select cast(1 as boolean) + interval 2 day +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) + interval 2 days)' (boolean and calendarinterval).; line 1 pos 7 + + +-- !query 12 +select cast('2017-12-11 09:30:00.0' as timestamp) + interval 2 day +-- !query 12 schema +struct +-- !query 12 output +2017-12-13 09:30:00 + + +-- !query 13 +select cast('2017-12-11 09:30:00' as date) + interval 2 day +-- !query 13 schema +struct +-- !query 13 output +2017-12-13 + + +-- !query 14 +select interval 2 day + cast(1 as tinyint) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS TINYINT))' (calendarinterval and tinyint).; line 1 pos 7 + + +-- !query 15 +select interval 2 day + cast(1 as smallint) +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS SMALLINT))' (calendarinterval and smallint).; line 1 pos 7 + + +-- !query 16 +select interval 2 day + cast(1 as int) +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS INT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS INT))' (calendarinterval and int).; line 1 pos 7 + + +-- !query 17 +select interval 2 day + cast(1 as bigint) +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS BIGINT))' (calendarinterval and bigint).; line 1 pos 7 + + +-- !query 18 +select interval 2 day + cast(1 as float) +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS FLOAT))' (calendarinterval and float).; line 1 pos 7 + + +-- !query 19 +select interval 2 day + cast(1 as double) +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS DOUBLE))' (calendarinterval and double).; line 1 pos 7 + + +-- !query 20 +select interval 2 day + cast(1 as decimal(10, 0)) +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS DECIMAL(10,0)))' (calendarinterval and decimal(10,0)).; line 1 pos 7 + + +-- !query 21 +select interval 2 day + cast('2017-12-11' as string) +-- !query 21 schema +struct +-- !query 21 output +2017-12-13 00:00:00 + + +-- !query 22 +select interval 2 day + cast('2017-12-11 09:30:00' as string) +-- !query 22 schema +struct +-- !query 22 output +2017-12-13 09:30:00 + + +-- !query 23 +select interval 2 day + cast('1' as binary) +-- !query 23 schema +struct<> 
+-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(interval 2 days + CAST('1' AS BINARY))' (calendarinterval and binary).; line 1 pos 7 + + +-- !query 24 +select interval 2 day + cast(1 as boolean) +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS BOOLEAN))' (calendarinterval and boolean).; line 1 pos 7 + + +-- !query 25 +select interval 2 day + cast('2017-12-11 09:30:00.0' as timestamp) +-- !query 25 schema +struct +-- !query 25 output +2017-12-13 09:30:00 + + +-- !query 26 +select interval 2 day + cast('2017-12-11 09:30:00' as date) +-- !query 26 schema +struct +-- !query 26 output +2017-12-13 + + +-- !query 27 +select cast(1 as tinyint) - interval 2 day +-- !query 27 schema +struct<> +-- !query 27 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) - interval 2 days)' (tinyint and calendarinterval).; line 1 pos 7 + + +-- !query 28 +select cast(1 as smallint) - interval 2 day +-- !query 28 schema +struct<> +-- !query 28 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) - interval 2 days)' (smallint and calendarinterval).; line 1 pos 7 + + +-- !query 29 +select cast(1 as int) - interval 2 day +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS INT) - interval 2 days)' (int and calendarinterval).; line 1 pos 7 + + +-- !query 30 +select cast(1 as bigint) - interval 2 day +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) - interval 2 days)' (bigint and calendarinterval).; line 1 pos 7 + + +-- !query 31 +select cast(1 as float) - interval 2 day +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) - interval 2 days)' (float and calendarinterval).; line 1 pos 7 + + +-- !query 32 +select cast(1 as double) - interval 2 day +-- !query 32 schema +struct<> +-- !query 32 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) - interval 2 days)' (double and calendarinterval).; line 1 pos 7 + + +-- !query 33 +select cast(1 as decimal(10, 0)) - interval 2 day +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - interval 2 days)' (decimal(10,0) and calendarinterval).; line 1 pos 7 + + +-- !query 34 +select cast('2017-12-11' as string) - interval 2 day +-- !query 34 schema +struct +-- !query 34 output +2017-12-09 00:00:00 + + +-- !query 35 +select cast('2017-12-11 09:30:00' as string) - interval 2 day +-- !query 35 schema +struct 
+-- !query 35 output +2017-12-09 09:30:00 + + +-- !query 36 +select cast('1' as binary) - interval 2 day +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) - interval 2 days)' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - interval 2 days)' (binary and calendarinterval).; line 1 pos 7 + + +-- !query 37 +select cast(1 as boolean) - interval 2 day +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) - interval 2 days)' (boolean and calendarinterval).; line 1 pos 7 + + +-- !query 38 +select cast('2017-12-11 09:30:00.0' as timestamp) - interval 2 day +-- !query 38 schema +struct +-- !query 38 output +2017-12-09 09:30:00 + + +-- !query 39 +select cast('2017-12-11 09:30:00' as date) - interval 2 day +-- !query 39 schema +struct +-- !query 39 output +2017-12-09 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1972dec7b86c..5e077285ade5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext import java.net.{MalformedURLException, URL} -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.{AccumulatorSuite, SparkException} @@ -2760,17 +2760,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("SPARK-22894: DateTimeOperations should accept SQL like string type") { - val date = "2017-12-24" - val str = sql(s"SELECT CAST('$date' as STRING) + interval 2 months 2 seconds") - val dt = sql(s"SELECT CAST('$date' as DATE) + interval 2 months 2 seconds") - val ts = sql(s"SELECT CAST('$date' as TIMESTAMP) + interval 2 months 2 seconds") - - checkAnswer(str, Row("2018-02-24 00:00:02") :: Nil) - checkAnswer(dt, Row(Date.valueOf("2018-02-24")) :: Nil) - checkAnswer(ts, Row(Timestamp.valueOf("2018-02-24 00:00:02")) :: Nil) - } - // Only New OrcFileFormat supports this Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, "parquet").foreach { format => From cfcd746689c2b84824745fa6d327ffb584c7a17d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 28 Dec 2017 17:00:49 -0600 Subject: [PATCH 032/100] [SPARK-11035][CORE] Add in-process Spark app launcher. This change adds a new launcher that allows applications to be run in a separate thread in the same process as the calling code. To achieve that, some code from the child process implementation was moved to abstract classes that implement the common functionality, and the new launcher inherits from those. The new launcher was added as a new class, instead of implemented as a new option to the existing SparkLauncher, to avoid ambigous APIs. For example, SparkLauncher has ways to set the child app's environment, modify SPARK_HOME, or control the logging of the child process, none of which apply to in-process apps. The in-process launcher has limitations: it needs Spark in the context class loader of the calling thread, and it's bound by Spark's current limitation of a single client-mode application per JVM. 
It also relies on the recently added SparkApplication trait to make sure different apps don't mess up each other's configuration, so config isolation is currently limited to cluster mode. I also chose to keep the same socket-based communication for in-process apps, even though it might be possible to avoid it for in-process mode. That helps both implementations share more code. Tested with new and existing unit tests, and with a simple app that uses the launcher; also made sure the app ran fine with older launcher jar to check binary compatibility. Author: Marcelo Vanzin Closes #19591 from vanzin/SPARK-11035. --- core/pom.xml | 8 + .../spark/launcher/LauncherBackend.scala | 11 +- .../cluster/StandaloneSchedulerBackend.scala | 1 + .../local/LocalSchedulerBackend.scala | 1 + .../spark/launcher/SparkLauncherSuite.java | 70 +++- .../spark/launcher/AbstractAppHandle.java | 129 +++++++ .../spark/launcher/AbstractLauncher.java | 307 +++++++++++++++ .../spark/launcher/ChildProcAppHandle.java | 117 +----- .../spark/launcher/InProcessAppHandle.java | 83 +++++ .../spark/launcher/InProcessLauncher.java | 110 ++++++ .../spark/launcher/LauncherProtocol.java | 6 + .../apache/spark/launcher/LauncherServer.java | 113 +++--- .../apache/spark/launcher/SparkLauncher.java | 351 +++++------------- .../launcher/SparkSubmitCommandBuilder.java | 2 +- .../apache/spark/launcher/package-info.java | 28 +- .../org/apache/spark/launcher/BaseSuite.java | 35 +- .../launcher/ChildProcAppHandleSuite.java | 16 - .../launcher/InProcessLauncherSuite.java | 170 +++++++++ .../spark/launcher/LauncherServerSuite.java | 82 ++-- .../MesosCoarseGrainedSchedulerBackend.scala | 2 + .../org/apache/spark/deploy/yarn/Client.scala | 2 + 21 files changed, 1139 insertions(+), 505 deletions(-) create mode 100644 launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java create mode 100644 launcher/src/test/java/org/apache/spark/launcher/InProcessLauncherSuite.java diff --git a/core/pom.xml b/core/pom.xml index fa138d3e7a4e..0a5bd958fc9c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -351,6 +351,14 @@ spark-tags_${scala.binary.version} + + org.apache.spark + spark-launcher_${scala.binary.version} + ${project.version} + tests + test + + 0.10.2 - 4.5.2 - 4.4.4 + 4.5.4 + 4.4.8 3.1 3.4.1 From ea0a5eef2238daa68a15b60a6f1a74c361216140 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 31 Dec 2017 02:50:00 +0900 Subject: [PATCH 055/100] [SPARK-22924][SPARKR] R API for sortWithinPartitions ## What changes were proposed in this pull request? Add to `arrange` the option to sort only within partition ## How was this patch tested? manual, unit tests Author: Felix Cheung Closes #20118 from felixcheung/rsortwithinpartition. --- R/pkg/R/DataFrame.R | 14 ++++++++++---- R/pkg/tests/fulltests/test_sparkSQL.R | 5 +++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index ace49daf9cd8..fe238f6dd4eb 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2297,6 +2297,7 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @param ... 
additional sorting fields #' @param decreasing a logical argument indicating sorting order for columns when #' a character vector is specified for col +#' @param withinPartitions a logical argument indicating whether to sort only within each partition #' @return A SparkDataFrame where all elements are sorted. #' @family SparkDataFrame functions #' @aliases arrange,SparkDataFrame,Column-method @@ -2312,16 +2313,21 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' arrange(df, asc(df$col1), desc(abs(df$col2))) #' arrange(df, "col1", decreasing = TRUE) #' arrange(df, "col1", "col2", decreasing = c(TRUE, FALSE)) +#' arrange(df, "col1", "col2", withinPartitions = TRUE) #' } #' @note arrange(SparkDataFrame, Column) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "Column"), - function(x, col, ...) { + function(x, col, ..., withinPartitions = FALSE) { jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "sort", jcols) + if (withinPartitions) { + sdf <- callJMethod(x@sdf, "sortWithinPartitions", jcols) + } else { + sdf <- callJMethod(x@sdf, "sort", jcols) + } dataFrame(sdf) }) @@ -2332,7 +2338,7 @@ setMethod("arrange", #' @note arrange(SparkDataFrame, character) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "character"), - function(x, col, ..., decreasing = FALSE) { + function(x, col, ..., decreasing = FALSE, withinPartitions = FALSE) { # all sorting columns by <- list(col, ...) @@ -2356,7 +2362,7 @@ setMethod("arrange", } }) - do.call("arrange", c(x, jcols)) + do.call("arrange", c(x, jcols, withinPartitions = withinPartitions)) }) #' @rdname arrange diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 1b7d53fd73a0..5197838eaac6 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2130,6 +2130,11 @@ test_that("arrange() and orderBy() on a DataFrame", { sorted7 <- arrange(df, "name", decreasing = FALSE) expect_equal(collect(sorted7)[2, "age"], 19) + + df <- createDataFrame(cars, numPartitions = 10) + expect_equal(getNumPartitions(df), 10) + sorted8 <- arrange(df, "dist", withinPartitions = TRUE) + expect_equal(collect(sorted8)[5:6, "dist"], c(22, 10)) }) test_that("filter() on a DataFrame", { From ee3af15fea18356a9223d61cfe6aaa98ab4dc733 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Sun, 31 Dec 2017 14:47:23 +0800 Subject: [PATCH 056/100] [SPARK-22363][SQL][TEST] Add unit test for Window spilling ## What changes were proposed in this pull request? There is already test using window spilling, but the test coverage is not ideal. In this PR the already existing test was fixed and additional cases added. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20022 from gaborgsomogyi/SPARK-22363. 
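The two thresholds exercised by these tests can also be driven interactively; with both set to 1, the per-partition row buffer gives up its in-memory array almost immediately and spills to disk, which is what the new `assertSpilled` checks assert. A minimal sketch (a spark-shell `spark` session is assumed; the config entries are the same `SQLConf` keys used in the tests below):

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.internal.SQLConf
import spark.implicits._

spark.conf.set(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key, "1")
spark.conf.set(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key, "1")

val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
val w = Window.partitionBy($"key").orderBy($"value")
df.select($"key", sum("value").over(w)).collect()  // buffering rows per window partition now spills
```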
--- .../sql/DataFrameWindowFunctionsSuite.scala | 44 ++++++++++++++++++- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index ea725af8d1ad..01c988ecc372 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -518,9 +519,46 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) } + test("Window spill with less than the inMemoryThreshold") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "2", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "2") { + assertNotSpilled(sparkContext, "select") { + df.select($"key", sum("value").over(window)).collect() + } + } + } + + test("Window spill with more than the inMemoryThreshold but less than the spillThreshold") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "2") { + assertNotSpilled(sparkContext, "select") { + df.select($"key", sum("value").over(window)).collect() + } + } + } + + test("Window spill with more than the inMemoryThreshold and spillThreshold") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") { + assertSpilled(sparkContext, "select") { + df.select($"key", sum("value").over(window)).collect() + } + } + } + test("SPARK-21258: complex object in combination with spilling") { // Make sure we trigger the spilling path. - withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { val sampleSchema = new StructType(). add("f0", StringType). add("f1", LongType). @@ -558,7 +596,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ - spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + assertSpilled(sparkContext, "select") { + spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + } } } } From cfbe11e8164c04cd7d388e4faeded21a9331dac4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 31 Dec 2017 15:06:54 +0800 Subject: [PATCH 057/100] [SPARK-22895][SQL] Push down the deterministic predicates that are after the first non-deterministic ## What changes were proposed in this pull request? 
Currently, we do not guarantee an order evaluation of conjuncts in either Filter or Join operator. This is also true to the mainstream RDBMS vendors like DB2 and MS SQL Server. Thus, we should also push down the deterministic predicates that are after the first non-deterministic, if possible. ## How was this patch tested? Updated the existing test cases. Author: gatorsmile Closes #20069 from gatorsmile/morePushDown. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/optimizer/Optimizer.scala | 40 ++++++++----------- .../optimizer/FilterPushdownSuite.scala | 33 +++++++-------- .../v2/PushDownOperatorsToDataSource.scala | 10 ++--- .../execution/python/ExtractPythonUDFs.scala | 6 +-- .../StreamingSymmetricHashJoinHelper.scala | 5 +-- .../python/BatchEvalPythonExecSuite.scala | 10 +++-- ...treamingSymmetricHashJoinHelperSuite.scala | 14 +++---- 8 files changed, 54 insertions(+), 65 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4b5f56c44444..dc3e384008d2 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1636,6 +1636,7 @@ options. - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. + - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. - Partition column inference previously found incorrect common type for different inferred types, for example, previously it ended up with double type as the common type for double type and date type. Now it finds the correct common type for such conflicts. The conflict resolution follows the table below: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index eeb1b130c578..0d4b02c6e7d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -805,15 +805,15 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // For each filter, expand the alias and check if the filter can be evaluated using // attributes produced by the aggregate operator's child operator. 
- val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition { cond => val replaced = replaceAlias(cond, aliasMap) cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet) } - val stayUp = rest ++ containingNonDeterministic + val stayUp = rest ++ nonDeterministic if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) @@ -835,14 +835,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition { cond => cond.references.subsetOf(partitionAttrs) } - val stayUp = rest ++ containingNonDeterministic + val stayUp = rest ++ nonDeterministic if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) @@ -854,7 +854,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down - val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic) + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition(_.deterministic) if (pushDown.nonEmpty) { val pushDownCond = pushDown.reduceLeft(And) @@ -878,13 +878,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } case filter @ Filter(condition, watermark: EventTimeWatermark) => - // We can only push deterministic predicates which don't reference the watermark attribute. - // We could in theory span() only on determinism and pull out deterministic predicates - // on the watermark separately. But it seems unnecessary and a bit confusing to not simply - // use the prefix as we do for nondeterminism in other cases. - - val (pushDown, stayUp) = splitConjunctivePredicates(condition).span( - p => p.deterministic && !p.references.contains(watermark.eventTime)) + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { p => + p.deterministic && !p.references.contains(watermark.eventTime) + } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduceLeft(And) @@ -925,14 +921,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // come from grandchild. // TODO: non-deterministic predicates could be pushed through some operators that do not change // the rows. 
- val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(filter.condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(filter.condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition { cond => cond.references.subsetOf(grandchild.outputSet) } - val stayUp = rest ++ containingNonDeterministic + val stayUp = rest ++ nonDeterministic if (pushDown.nonEmpty) { val newChild = insertFilter(pushDown.reduceLeft(And)) @@ -975,23 +971,19 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Splits join condition expressions or filter predicates (on a given join's output) into three * categories based on the attributes required to evaluate them. Note that we explicitly exclude - * on-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or + * non-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or * canEvaluateInRight to prevent pushing these predicates on either side of the join. * * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { - // Note: In order to ensure correctness, it's important to not change the relative ordering of - // any deterministic expression that follows a non-deterministic expression. To achieve this, - // we only consider pushing down those expressions that precede the first non-deterministic - // expression in the condition. - val (pushDownCandidates, containingNonDeterministic) = condition.span(_.deterministic) + val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic) val (leftEvaluateCondition, rest) = pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) val (rightEvaluateCondition, commonCondition) = rest.partition(expr => expr.references.subsetOf(right.outputSet)) - (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ containingNonDeterministic) + (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index a9c2306e37fd..85a5e979f602 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -831,9 +831,9 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Union(Seq( - testRelation.where('a === 2L), - testRelation2.where('d === 2L))) - .where('b + Rand(10).as("rnd") === 3 && 'c > 5L) + testRelation.where('a === 2L && 'c > 5L), + testRelation2.where('d === 2L && 'f > 5L))) + .where('b + Rand(10).as("rnd") === 3) .analyze comparePlans(optimized, correctAnswer) @@ -1134,12 +1134,13 @@ class FilterPushdownSuite extends PlanTest { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - // Verify that all conditions preceding the first non-deterministic condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. 
val originalQuery = x.join(y, condition = Some("x.a".attr === 5 && "y.a".attr === 5 && "x.a".attr === Rand(10) && "y.b".attr === 5)) - val correctAnswer = x.where("x.a".attr === 5).join(y.where("y.a".attr === 5), - condition = Some("x.a".attr === Rand(10) && "y.b".attr === 5)) + val correctAnswer = + x.where("x.a".attr === 5).join(y.where("y.a".attr === 5 && "y.b".attr === 5), + condition = Some("x.a".attr === Rand(10))) // CheckAnalysis will ensure nondeterministic expressions not appear in join condition. // TODO support nondeterministic expressions in join condition. @@ -1147,16 +1148,16 @@ class FilterPushdownSuite extends PlanTest { checkAnalysis = false) } - test("watermark pushdown: no pushdown on watermark attribute") { + test("watermark pushdown: no pushdown on watermark attribute #1") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. val originalQuery = EventTimeWatermark('b, interval, testRelation) .where('a === 5 && 'b === 10 && 'c === 5) val correctAnswer = EventTimeWatermark( - 'b, interval, testRelation.where('a === 5)) - .where('b === 10 && 'c === 5) + 'b, interval, testRelation.where('a === 5 && 'c === 5)) + .where('b === 10) comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, checkAnalysis = false) @@ -1165,7 +1166,7 @@ class FilterPushdownSuite extends PlanTest { test("watermark pushdown: no pushdown for nondeterministic filter") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. val originalQuery = EventTimeWatermark('c, interval, testRelation) .where('a === 5 && 'b === Rand(10) && 'c === 5) @@ -1180,7 +1181,7 @@ class FilterPushdownSuite extends PlanTest { test("watermark pushdown: full pushdown") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. val originalQuery = EventTimeWatermark('c, interval, testRelation) .where('a === 5 && 'b === 10) @@ -1191,15 +1192,15 @@ class FilterPushdownSuite extends PlanTest { checkAnalysis = false) } - test("watermark pushdown: empty pushdown") { + test("watermark pushdown: no pushdown on watermark attribute #2") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down - // by the optimizer and others are not. 
val originalQuery = EventTimeWatermark('a, interval, testRelation) .where('a === 5 && 'b === 10) + val correctAnswer = EventTimeWatermark( + 'a, interval, testRelation.where('b === 10)).where('a === 5) - comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze, + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, checkAnalysis = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 0c1708131ae4..df034adf1e7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -40,12 +40,8 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // top-down, then we can simplify the logic here and only collect target operators. val filterPushed = plan transformUp { case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => - // Non-deterministic expressions are stateful and we must keep the input sequence unchanged - // to avoid changing the result. This means, we can't evaluate the filter conditions that - // are after the first non-deterministic condition ahead. Here we only try to push down - // deterministic conditions that are before the first non-deterministic condition. - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(condition).partition(_.deterministic) val stayUpFilters: Seq[Expression] = reader match { case r: SupportsPushDownCatalystFilters => @@ -74,7 +70,7 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel case _ => candidates } - val filterCondition = (stayUpFilters ++ containingNonDeterministic).reduceLeftOption(And) + val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And) val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) if (withFilter.output == fields) { withFilter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index f5a4cbc4793e..2f53fe788c7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -202,12 +202,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { private def trySplitFilter(plan: SparkPlan): SparkPlan = { plan match { case filter: FilterExec => - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(filter.condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(filter.condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition(!hasPythonUDF(_)) if (pushDown.nonEmpty) { val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) - FilterExec((rest ++ containingNonDeterministic).reduceLeft(And), newChild) + FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild) } else { filter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 167e991ca62f..217e98a6419e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -72,8 +72,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { * left AND right AND joined is equivalent to full. * * Note that left and right do not necessarily contain *all* conjuncts which satisfy - * their condition. Any conjuncts after the first nondeterministic one are treated as - * nondeterministic for purposes of the split. + * their condition. * * @param leftSideOnly Deterministic conjuncts which reference only the left side of the join. * @param rightSideOnly Deterministic conjuncts which reference only the right side of the join. @@ -111,7 +110,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { // Span rather than partition, because nondeterministic expressions don't commute // across AND. val (deterministicConjuncts, nonDeterministicConjuncts) = - splitConjunctivePredicates(condition.get).span(_.deterministic) + splitConjunctivePredicates(condition.get).partition(_.deterministic) val (leftConjuncts, nonLeftConjuncts) = deterministicConjuncts.partition { cond => cond.references.subsetOf(left.outputSet) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 9e4a2e877695..d456c931f527 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -75,13 +75,17 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { assert(qualifiedPlanNodes.size == 2) } - test("Python UDF: no push down on predicates starting from the first non-deterministic") { + test("Python UDF: push down on deterministic predicates after the first non-deterministic") { val df = Seq(("Hello", 4)).toDF("a", "b") .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4") + val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { - case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: BatchEvalPythonExec)) => f + case f @ FilterExec( + And(_: AttributeReference, _: GreaterThan), + InputAdapter(_: BatchEvalPythonExec)) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b } - assert(qualifiedPlanNodes.size == 1) + assert(qualifiedPlanNodes.size == 2) } test("Python UDF refers to the attributes from more than one child") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala index 2a854e37bf0d..69b715489534 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.streaming import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec, 
SparkPlan} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinConditionSplitPredicates import org.apache.spark.sql.types._ @@ -95,19 +93,17 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest { } test("conjuncts after nondeterministic") { - // All conjuncts after a nondeterministic conjunct shouldn't be split because they don't - // commute across it. val predicate = - (rand() > lit(0) + (rand(9) > lit(0) && leftColA > leftColB && rightColC > rightColD && leftColA === rightColC && lit(1) === lit(1)).expr val split = JoinConditionSplitPredicates(Some(predicate), left, right) - assert(split.leftSideOnly.isEmpty) - assert(split.rightSideOnly.isEmpty) - assert(split.bothSides.contains(predicate)) + assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) + assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) + assert(split.bothSides.contains((leftColA === rightColC && rand(9) > lit(0)).expr)) assert(split.full.contains(predicate)) } From 3d8837e59aadd726805371041567ceff375194c0 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 31 Dec 2017 14:39:24 +0200 Subject: [PATCH 058/100] [SPARK-22397][ML] add multiple columns support to QuantileDiscretizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? add multi columns support to QuantileDiscretizer. When calculating the splits, we can either merge together all the probabilities into one array by calculating approxQuantiles on multiple columns at once, or compute approxQuantiles separately for each column. After doing the performance comparision, we found it’s better to calculating approxQuantiles on multiple columns at once. Here is how we measuring the performance time: ``` var duration = 0.0 for (i<- 0 until 10) { val start = System.nanoTime() discretizer.fit(df) val end = System.nanoTime() duration += (end - start) / 1e9 } println(duration/10) ``` Here is the performance test result: |numCols |NumRows | compute each approxQuantiles separately|compute multiple columns approxQuantiles at one time| |--------|----------|--------------------------------|-------------------------------------------| |10 |60 |0.3623195839 |0.1626658607 | |10 |6000 |0.7537239841 |0.3869370046 | |22 |6000 |1.6497598557 |0.4767903059 | |50 |6000 |3.2268305752 |0.7217818396 | ## How was this patch tested? add UT in QuantileDiscretizerSuite to test multi columns supports Author: Huaxin Gao Closes #19715 from huaxingao/spark_22397. 
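For readers skimming the patch, a minimal sketch of how the new multi-column API is meant to be used (the setters are the ones introduced in the diff below; the column names, bucket counts and `df` are made up for illustration):

```
import org.apache.spark.ml.feature.QuantileDiscretizer

// df is assumed to be a DataFrame with numeric columns "input1" and "input2".
val discretizer = new QuantileDiscretizer()
  .setInputCols(Array("input1", "input2"))
  .setOutputCols(Array("result1", "result2"))
  .setNumBucketsArray(Array(5, 10))  // per-column bucket counts; or setNumBuckets(n) for all columns

// fit() computes the splits for both columns with a single multi-column approxQuantile
// call and returns a Bucketizer configured with one splits array per input column.
val bucketizer = discretizer.fit(df)
val binned = bucketizer.transform(df)
```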
--- .../ml/feature/QuantileDiscretizer.scala | 120 ++++++-- .../ml/feature/QuantileDiscretizerSuite.scala | 265 ++++++++++++++++++ 2 files changed, 369 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 95e8830283de..1ec5f8cb6139 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -50,10 +50,28 @@ private[feature] trait QuantileDiscretizerBase extends Params /** @group getParam */ def getNumBuckets: Int = getOrDefault(numBuckets) + /** + * Array of number of buckets (quantiles, or categories) into which data points are grouped. + * Each value must be greater than or equal to 2 + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * + * @group param + */ + val numBucketsArray = new IntArrayParam(this, "numBucketsArray", "Array of number of buckets " + + "(quantiles, or categories) into which data points are grouped. This is for multiple " + + "columns input. If transforming multiple columns and numBucketsArray is not set, but " + + "numBuckets is set, then numBuckets will be applied across all columns.", + (arrayOfNumBuckets: Array[Int]) => arrayOfNumBuckets.forall(ParamValidators.gtEq(2))) + + /** @group getParam */ + def getNumBucketsArray: Array[Int] = $(numBucketsArray) + /** * Relative error (see documentation for * `org.apache.spark.sql.DataFrameStatFunctions.approxQuantile` for description) * Must be in the range [0, 1]. + * Note that in multiple columns case, relative error is applied to all columns. * default: 0.001 * @group param */ @@ -68,7 +86,9 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * Param for how to handle invalid entries. Options are 'skip' (filter out rows with * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special - * additional bucket). + * additional bucket). Note that in the multiple columns case, the invalid handling is applied + * to all columns. That said for 'error' it will throw an error if any invalids are found in + * any column, for 'skip' it will skip rows with any invalids in any columns, etc. * Default: "error" * @group param */ @@ -86,6 +106,11 @@ private[feature] trait QuantileDiscretizerBase extends Params * categorical features. The number of bins can be set using the `numBuckets` parameter. It is * possible that the number of buckets used will be smaller than this value, for example, if there * are too few distinct values of the input to create enough distinct quantiles. + * Since 2.3.0, `QuantileDiscretizer` can map multiple columns at once by setting the `inputCols` + * parameter. If both of the `inputCol` and `inputCols` parameters are set, an Exception will be + * thrown. 
To specify the number of buckets for each column, the `numBucketsArray` parameter can + * be set, or if the number of buckets should be the same across columns, `numBuckets` can be + * set as a convenience. * * NaN handling: * null and NaN values will be ignored from the column during `QuantileDiscretizer` fitting. This @@ -104,7 +129,8 @@ private[feature] trait QuantileDiscretizerBase extends Params */ @Since("1.6.0") final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable + with HasInputCols with HasOutputCols { @Since("1.6.0") def this() = this(Identifiable.randomUID("quantileDiscretizer")) @@ -129,34 +155,96 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("2.1.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + /** @group setParam */ + @Since("2.3.0") + def setNumBucketsArray(value: Array[Int]): this.type = set(numBucketsArray, value) + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + private[feature] def getInOutCols: (Array[String], Array[String]) = { + require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) || + (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)), + "QuantileDiscretizer only supports setting either inputCol/outputCol or" + + "inputCols/outputCols." + ) + + if (isSet(inputCol)) { + (Array($(inputCol)), Array($(outputCol))) + } else { + require($(inputCols).length == $(outputCols).length, + "inputCols number do not match outputCols") + ($(inputCols), $(outputCols)) + } + } + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkNumericType(schema, $(inputCol)) - val inputFields = schema.fields - require(inputFields.forall(_.name != $(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val attr = NominalAttribute.defaultAttr.withName($(outputCol)) - val outputFields = inputFields :+ attr.toStructField() + val (inputColNames, outputColNames) = getInOutCols + val existingFields = schema.fields + var outputFields = existingFields + inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) => + SchemaUtils.checkNumericType(schema, inputColName) + require(existingFields.forall(_.name != outputColName), + s"Output column ${outputColName} already exists.") + val attr = NominalAttribute.defaultAttr.withName(outputColName) + outputFields :+= attr.toStructField() + } StructType(outputFields) } @Since("2.0.0") override def fit(dataset: Dataset[_]): Bucketizer = { transformSchema(dataset.schema, logging = true) - val splits = dataset.stat.approxQuantile($(inputCol), - (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError)) + val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid)) + if (isSet(inputCols)) { + val splitsArray = if (isSet(numBucketsArray)) { + val probArrayPerCol = $(numBucketsArray).map { numOfBuckets => + (0.0 to 1.0 by 1.0 / numOfBuckets).toArray + } + + val probabilityArray = probArrayPerCol.flatten.sorted.distinct + val splitsArrayRaw = dataset.stat.approxQuantile($(inputCols), + probabilityArray, $(relativeError)) + + 
splitsArrayRaw.zip(probArrayPerCol).map { case (splits, probs) => + val probSet = probs.toSet + val idxSet = probabilityArray.zipWithIndex.collect { + case (p, idx) if probSet(p) => + idx + }.toSet + splits.zipWithIndex.collect { + case (s, idx) if idxSet(idx) => + s + } + } + } else { + dataset.stat.approxQuantile($(inputCols), + (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) + } + bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits)) + } else { + val splits = dataset.stat.approxQuantile($(inputCol), + (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) + bucketizer.setSplits(getDistinctSplits(splits)) + } + copyValues(bucketizer.setParent(this)) + } + + private def getDistinctSplits(splits: Array[Double]): Array[Double] = { splits(0) = Double.NegativeInfinity splits(splits.length - 1) = Double.PositiveInfinity - val distinctSplits = splits.distinct if (splits.length != distinctSplits.length) { log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + s" buckets as a result.") } - val bucketizer = new Bucketizer(uid) - .setSplits(distinctSplits.sorted) - .setHandleInvalid($(handleInvalid)) - copyValues(bucketizer.setParent(this)) + distinctSplits.sorted } @Since("1.6.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index f219f775b218..e9a75e931e6a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.Pipeline import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ @@ -146,4 +147,268 @@ class QuantileDiscretizerSuite val model = discretizer.fit(df) assert(model.hasParent) } + + test("Multiple Columns: Test observed number of buckets and their sizes match expected values") { + val spark = this.spark + import spark.implicits._ + + val datasetSize = 100000 + val numBuckets = 5 + val data1 = Array.range(1, 100001, 1).map(_.toDouble) + val data2 = Array.range(1, 200000, 2).map(_.toDouble) + val df = data1.zip(data2).toSeq.toDF("input1", "input2") + + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) + + val relativeError = discretizer.getRelativeError + val isGoodBucket = udf { + (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) + } + + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + + val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") + } + } + + test("Multiple Columns: Test on data with high proportion of duplicated values") { + val spark = this.spark + import spark.implicits._ + + val numBuckets = 5 + val expectedNumBucket = 3 + val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0) + val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 
1.0, 3.0, 2.0, 3.0, 1.0, 2.0) + val df = data1.zip(data2).toSeq.toDF("input1", "input2") + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets == expectedNumBucket, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBucket but found ($observedNumBuckets") + } + } + + test("Multiple Columns: Test transform on data with NaN value") { + val spark = this.spark + import spark.implicits._ + + val numBuckets = 3 + val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) + val expectedKeep1 = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0) + val expectedSkip1 = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0) + val validData2 = Array(0.2, -0.1, 0.3, 0.0, 0.1, 0.3, 0.5, Double.NaN, Double.NaN, Double.NaN) + val expectedKeep2 = Array(1.0, 0.0, 2.0, 0.0, 1.0, 2.0, 2.0, 3.0, 3.0, 3.0) + val expectedSkip2 = Array(1.0, 0.0, 2.0, 0.0, 1.0, 2.0, 2.0) + + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBuckets(numBuckets) + + withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { + val dataFrame: DataFrame = validData1.zip(validData2).toSeq.toDF("input1", "input2") + intercept[SparkException] { + discretizer.fit(dataFrame).transform(dataFrame).collect() + } + } + + List(("keep", expectedKeep1, expectedKeep2), ("skip", expectedSkip1, expectedSkip2)).foreach { + case (u, v, w) => + discretizer.setHandleInvalid(u) + val dataFrame: DataFrame = validData1.zip(validData2).zip(v).zip(w).map { + case (((a, b), c), d) => (a, b, c, d) + }.toSeq.toDF("input1", "input2", "expected1", "expected2") + val result = discretizer.fit(dataFrame).transform(dataFrame) + result.select("result1", "expected1", "result2", "expected2").collect().foreach { + case Row(x: Double, y: Double, z: Double, w: Double) => + assert(x === y && w === z) + } + } + } + + test("Multiple Columns: Test numBucketsArray") { + val spark = this.spark + import spark.implicits._ + + val numBucketsArray: Array[Int] = Array(2, 5, 10) + val data1 = Array.range(1, 21, 1).map(_.toDouble) + val expected1 = Array (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) + val data2 = Array.range(1, 40, 2).map(_.toDouble) + val expected2 = Array (0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, + 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0) + val data3 = Array.range(1, 60, 3).map(_.toDouble) + val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0, + 5.0, 5.0, 6.0, 6.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0) + val data = (0 until 20).map { idx => + (data1(idx), data2(idx), data3(idx), expected1(idx), expected2(idx), expected3(idx)) + } + val df = + data.toDF("input1", "input2", "input3", "expected1", "expected2", "expected3") + + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("result1", "result2", "result3")) + .setNumBucketsArray(numBucketsArray) + + discretizer.fit(df).transform(df). 
+ select("result1", "expected1", "result2", "expected2", "result3", "expected3") + .collect().foreach { + case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) => + assert(r1 === e1, + s"The result value is not correct after bucketing. Expected $e1 but found $r1") + assert(r2 === e2, + s"The result value is not correct after bucketing. Expected $e2 but found $r2") + assert(r3 === e3, + s"The result value is not correct after bucketing. Expected $e3 but found $r3") + } + } + + test("Multiple Columns: Compare single/multiple column(s) QuantileDiscretizer in pipeline") { + val spark = this.spark + import spark.implicits._ + + val numBucketsArray: Array[Int] = Array(2, 5, 10) + val data1 = Array.range(1, 21, 1).map(_.toDouble) + val data2 = Array.range(1, 40, 2).map(_.toDouble) + val data3 = Array.range(1, 60, 3).map(_.toDouble) + val data = (0 until 20).map { idx => + (data1(idx), data2(idx), data3(idx)) + } + val df = + data.toDF("input1", "input2", "input3") + + val multiColsDiscretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("result1", "result2", "result3")) + .setNumBucketsArray(numBucketsArray) + val plForMultiCols = new Pipeline() + .setStages(Array(multiColsDiscretizer)) + .fit(df) + + val discretizerForCol1 = new QuantileDiscretizer() + .setInputCol("input1") + .setOutputCol("result1") + .setNumBuckets(numBucketsArray(0)) + + val discretizerForCol2 = new QuantileDiscretizer() + .setInputCol("input2") + .setOutputCol("result2") + .setNumBuckets(numBucketsArray(1)) + + val discretizerForCol3 = new QuantileDiscretizer() + .setInputCol("input3") + .setOutputCol("result3") + .setNumBuckets(numBucketsArray(2)) + + val plForSingleCol = new Pipeline() + .setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3)) + .fit(df) + + val resultForMultiCols = plForMultiCols.transform(df) + .select("result1", "result2", "result3") + .collect() + + val resultForSingleCol = plForSingleCol.transform(df) + .select("result1", "result2", "result3") + .collect() + + resultForSingleCol.zip(resultForMultiCols).foreach { + case (rowForSingle, rowForMultiCols) => + assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) && + rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) && + rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2)) + } + } + + test("Multiple Columns: Comparing setting numBuckets with setting numBucketsArray " + + "explicitly with identical values") { + val spark = this.spark + import spark.implicits._ + + val data1 = Array.range(1, 21, 1).map(_.toDouble) + val data2 = Array.range(1, 40, 2).map(_.toDouble) + val data3 = Array.range(1, 60, 3).map(_.toDouble) + val data = (0 until 20).map { idx => + (data1(idx), data2(idx), data3(idx)) + } + val df = + data.toDF("input1", "input2", "input3") + + val discretizerSingleNumBuckets = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("result1", "result2", "result3")) + .setNumBuckets(10) + + val discretizerNumBucketsArray = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("result1", "result2", "result3")) + .setNumBucketsArray(Array(10, 10, 10)) + + val result1 = discretizerSingleNumBuckets.fit(df).transform(df) + .select("result1", "result2", "result3") + .collect() + val result2 = discretizerNumBucketsArray.fit(df).transform(df) + .select("result1", "result2", "result3") + .collect() + + result1.zip(result2).foreach { + 
case (row1, row2) => + assert(row1.getDouble(0) == row2.getDouble(0) && + row1.getDouble(1) == row2.getDouble(1) && + row1.getDouble(2) == row2.getDouble(2)) + } + } + + test("Multiple Columns: read/write") { + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBucketsArray(Array(5, 10)) + testDefaultReadWrite(discretizer) + } + + test("Multiple Columns: Both inputCol and inputCols are set") { + val spark = this.spark + import spark.implicits._ + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(3) + .setInputCols(Array("input1", "input2")) + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + .map(Tuple1.apply).toDF("input") + // When both inputCol and inputCols are set, we throw Exception. + intercept[IllegalArgumentException] { + discretizer.fit(df) + } + } + + test("Multiple Columns: Mismatched sizes of inputCols / outputCols") { + val spark = this.spark + import spark.implicits._ + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input")) + .setOutputCols(Array("result1", "result2")) + .setNumBuckets(3) + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + .map(Tuple1.apply).toDF("input") + intercept[IllegalArgumentException] { + discretizer.fit(df) + } + } } From 028ee40165315337e2a349b19731764d64e4f51d Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Sun, 31 Dec 2017 14:51:38 +0200 Subject: [PATCH 059/100] [SPARK-22801][ML][PYSPARK] Allow FeatureHasher to treat numeric columns as categorical Previously, `FeatureHasher` always treats numeric type columns as numbers and never as categorical features. It is quite common to have categorical features represented as numbers or codes in data sources. In order to hash these features as categorical, users must first explicitly convert them to strings which is cumbersome. Add a new param `categoricalCols` which specifies the numeric columns that should be treated as categorical features. ## How was this patch tested? New unit tests. Author: Nick Pentreath Closes #19991 from MLnick/hasher-num-cat. --- .../spark/ml/feature/FeatureHasher.scala | 38 +++++++++++++++---- .../spark/ml/feature/FeatureHasherSuite.scala | 25 ++++++++++++ python/pyspark/ml/feature.py | 34 +++++++++++++---- 3 files changed, 83 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index 4615daed20fb..a918dd4c075d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} @@ -40,9 +40,9 @@ import org.apache.spark.util.collection.OpenHashMap * The [[FeatureHasher]] transformer operates on multiple columns. 
Each column may contain either * numeric or categorical features. Behavior and handling of column data types is as follows: * -Numeric columns: For numeric features, the hash value of the column name is used to map the - * feature value to its index in the feature vector. Numeric features are never - * treated as categorical, even when they are integers. You must explicitly - * convert numeric columns containing categorical features to strings first. + * feature value to its index in the feature vector. By default, numeric features + * are not treated as categorical (even when they are integers). To treat them + * as categorical, specify the relevant columns in `categoricalCols`. * -String columns: For categorical features, the hash value of the string "column_name=value" * is used to map to the vector index, with an indicator value of `1.0`. * Thus, categorical features are "one-hot" encoded @@ -86,6 +86,17 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme @Since("2.3.0") def this() = this(Identifiable.randomUID("featureHasher")) + /** + * Numeric columns to treat as categorical features. By default only string and boolean + * columns are treated as categorical, so this param can be used to explicitly specify the + * numerical columns to treat as categorical. Note, the relevant columns must also be set in + * `inputCols`. + * @group param + */ + @Since("2.3.0") + val categoricalCols = new StringArrayParam(this, "categoricalCols", + "numeric columns to treat as categorical") + /** * Number of features. Should be greater than 0. * (default = 2^18^) @@ -117,15 +128,28 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme @Since("2.3.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group getParam */ + @Since("2.3.0") + def getCategoricalCols: Array[String] = $(categoricalCols) + + /** @group setParam */ + @Since("2.3.0") + def setCategoricalCols(value: Array[String]): this.type = set(categoricalCols, value) + @Since("2.3.0") override def transform(dataset: Dataset[_]): DataFrame = { val hashFunc: Any => Int = OldHashingTF.murmur3Hash val n = $(numFeatures) val localInputCols = $(inputCols) + val catCols = if (isSet(categoricalCols)) { + $(categoricalCols).toSet + } else { + Set[String]() + } val outputSchema = transformSchema(dataset.schema) val realFields = outputSchema.fields.filter { f => - f.dataType.isInstanceOf[NumericType] + f.dataType.isInstanceOf[NumericType] && !catCols.contains(f.name) }.map(_.name).toSet def getDouble(x: Any): Double = { @@ -149,8 +173,8 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme val hash = hashFunc(colName) (hash, value) } else { - // string and boolean values are treated as categorical, with an indicator value of 1.0 - // and vector index based on hash of "column_name=value" + // string, boolean and numeric values that are in catCols are treated as categorical, + // with an indicator value of 1.0 and vector index based on hash of "column_name=value" val value = row.get(fieldIndex).toString val fieldName = s"$colName=$value" val hash = hashFunc(fieldName) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala index 407371a82666..3fc3cbb62d5b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala @@ -78,6 +78,31 @@ 
class FeatureHasherSuite extends SparkFunSuite assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 }) } + test("setting explicit numerical columns to treat as categorical") { + val df = Seq( + (2.0, 1, "foo"), + (3.0, 2, "bar") + ).toDF("real", "int", "string") + + val n = 100 + val hasher = new FeatureHasher() + .setInputCols("real", "int", "string") + .setCategoricalCols(Array("real")) + .setOutputCol("features") + .setNumFeatures(n) + val output = hasher.transform(df) + + val features = output.select("features").as[Vector].collect() + // Assume perfect hash on field names + def idx: Any => Int = murmur3FeatureIdx(n) + // check expected indices + val expected = Seq( + Vectors.sparse(n, Seq((idx("real=2.0"), 1.0), (idx("int"), 1.0), (idx("string=foo"), 1.0))), + Vectors.sparse(n, Seq((idx("real=3.0"), 1.0), (idx("int"), 2.0), (idx("string=bar"), 1.0))) + ) + assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 }) + } + test("hashing works for all numeric types") { val df = Seq(5.0, 10.0, 15.0).toDF("real") diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 5094324e5c1f..13bf95cce40b 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -714,9 +714,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, * Numeric columns: For numeric features, the hash value of the column name is used to map the - feature value to its index in the feature vector. Numeric features are never - treated as categorical, even when they are integers. You must explicitly - convert numeric columns containing categorical features to strings first. + feature value to its index in the feature vector. By default, numeric features + are not treated as categorical (even when they are integers). To treat them + as categorical, specify the relevant columns in `categoricalCols`. * String columns: For categorical features, the hash value of the string "column_name=value" @@ -741,6 +741,8 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, >>> hasher = FeatureHasher(inputCols=cols, outputCol="features") >>> hasher.transform(df).head().features SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0}) + >>> hasher.setCategoricalCols(["real"]).transform(df).head().features + SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0}) >>> hasherPath = temp_path + "/hasher" >>> hasher.save(hasherPath) >>> loadedHasher = FeatureHasher.load(hasherPath) @@ -752,10 +754,14 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, .. 
versionadded:: 2.3.0 """ + categoricalCols = Param(Params._dummy(), "categoricalCols", + "numeric columns to treat as categorical", + typeConverter=TypeConverters.toListString) + @keyword_only - def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None): + def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None): """ - __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None) + __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None) """ super(FeatureHasher, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.FeatureHasher", self.uid) @@ -765,14 +771,28 @@ def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None): @keyword_only @since("2.3.0") - def setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None): + def setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None): """ - setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None) + setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None) Sets params for this FeatureHasher. """ kwargs = self._input_kwargs return self._set(**kwargs) + @since("2.3.0") + def setCategoricalCols(self, value): + """ + Sets the value of :py:attr:`categoricalCols`. + """ + return self._set(categoricalCols=value) + + @since("2.3.0") + def getCategoricalCols(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.categoricalCols) + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, From 5955a2d0fb8f84560925d72afe36055089f44a05 Mon Sep 17 00:00:00 2001 From: Jirka Kremser Date: Sun, 31 Dec 2017 15:38:10 -0600 Subject: [PATCH 060/100] [MINOR][DOCS] s/It take/It takes/g ## What changes were proposed in this pull request? Fixing three small typos in the docs, in particular: It take a `RDD` -> It takes an `RDD` (twice) It take an `JavaRDD` -> It takes a `JavaRDD` I didn't create any Jira issue for this minor thing, I hope it's ok. ## How was this patch tested? visually by clicking on 'preview' Author: Jirka Kremser Closes #20108 from Jiri-Kremser/docs-typo. --- docs/mllib-frequent-pattern-mining.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index c9cd7cc85e75..0d3192c6b1d9 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -41,7 +41,7 @@ We refer users to the papers for more details. [`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. -It take a `RDD` of transactions, where each transaction is an `Array` of items of a generic type. +It takes an `RDD` of transactions, where each transaction is an `Array` of items of a generic type. Calling `FPGrowth.run` with transactions returns an [`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) that stores the frequent itemsets with their frequencies. The following @@ -60,7 +60,7 @@ Refer to the [`FPGrowth` Scala docs](api/scala/index.html#org.apache.spark.mllib [`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the FP-growth algorithm. -It take an `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. +It takes a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. 
Calling `FPGrowth.run` with transactions returns an [`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) that stores the frequent itemsets with their frequencies. The following @@ -79,7 +79,7 @@ Refer to the [`FPGrowth` Java docs](api/java/org/apache/spark/mllib/fpm/FPGrowth [`FPGrowth`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. -It take an `RDD` of transactions, where each transaction is an `List` of items of a generic type. +It takes an `RDD` of transactions, where each transaction is an `List` of items of a generic type. Calling `FPGrowth.train` with transactions returns an [`FPGrowthModel`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowthModel) that stores the frequent itemsets with their frequencies. From 994065d891a23ed89a09b3f95bc3f1f986793e0d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 31 Dec 2017 15:28:59 -0800 Subject: [PATCH 061/100] [SPARK-13030][ML] Create OneHotEncoderEstimator for OneHotEncoder as Estimator ## What changes were proposed in this pull request? This patch adds a new class `OneHotEncoderEstimator` which extends `Estimator`. The `fit` method returns `OneHotEncoderModel`. Common methods between existing `OneHotEncoder` and new `OneHotEncoderEstimator`, such as transforming schema, are extracted and put into `OneHotEncoderCommon` to reduce code duplication. ### Multi-column support `OneHotEncoderEstimator` adds simpler multi-column support because it is new API and can be free from backward compatibility. ### handleInvalid Param support `OneHotEncoderEstimator` supports `handleInvalid` Param. It supports `error` and `keep`. ## How was this patch tested? Added new test suite `OneHotEncoderEstimatorSuite`. Author: Liang-Chi Hsieh Closes #19527 from viirya/SPARK-13030. --- .../spark/ml/feature/OneHotEncoder.scala | 83 +-- .../ml/feature/OneHotEncoderEstimator.scala | 522 ++++++++++++++++++ .../feature/OneHotEncoderEstimatorSuite.scala | 421 ++++++++++++++ 3 files changed, 960 insertions(+), 66 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index a669da183e2c..5ab6c2dde667 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -41,8 +41,12 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} * The output vectors are sparse. * * @see `StringIndexer` for converting categorical values into category indices + * @deprecated `OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder` + * will be removed in 3.0.0. 
*/ @Since("1.4.0") +@deprecated("`OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`" + + " will be removed in 3.0.0.", "2.3.0") class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { @@ -78,56 +82,16 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) val outputColName = $(outputCol) + val inputFields = schema.fields require(schema(inputColName).dataType.isInstanceOf[NumericType], s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") - val inputFields = schema.fields require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") - val inputAttr = Attribute.fromStructField(schema(inputColName)) - val outputAttrNames: Option[Array[String]] = inputAttr match { - case nominal: NominalAttribute => - if (nominal.values.isDefined) { - nominal.values - } else if (nominal.numValues.isDefined) { - nominal.numValues.map(n => Array.tabulate(n)(_.toString)) - } else { - None - } - case binary: BinaryAttribute => - if (binary.values.isDefined) { - binary.values - } else { - Some(Array.tabulate(2)(_.toString)) - } - case _: NumericAttribute => - throw new RuntimeException( - s"The input column $inputColName cannot be numeric.") - case _ => - None // optimistic about unknown attributes - } - - val filteredOutputAttrNames = outputAttrNames.map { names => - if ($(dropLast)) { - require(names.length > 1, - s"The input column $inputColName should have at least two distinct values.") - names.dropRight(1) - } else { - names - } - } - - val outputAttrGroup = if (filteredOutputAttrNames.isDefined) { - val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name => - BinaryAttribute.defaultAttr.withName(name) - } - new AttributeGroup($(outputCol), attrs) - } else { - new AttributeGroup($(outputCol)) - } - - val outputFields = inputFields :+ outputAttrGroup.toStructField() + val outputField = OneHotEncoderCommon.transformOutputColumnSchema( + schema(inputColName), outputColName, $(dropLast)) + val outputFields = inputFields :+ outputField StructType(outputFields) } @@ -136,30 +100,17 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e // schema transformation val inputColName = $(inputCol) val outputColName = $(outputCol) - val shouldDropLast = $(dropLast) - var outputAttrGroup = AttributeGroup.fromStructField( + + val outputAttrGroupFromSchema = AttributeGroup.fromStructField( transformSchema(dataset.schema)(outputColName)) - if (outputAttrGroup.size < 0) { - // If the number of attributes is unknown, we check the values from the input column. 
- val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0)) - .treeAggregate(0.0)( - (m, x) => { - assert(x <= Int.MaxValue, - s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x") - assert(x >= 0.0 && x == x.toInt, - s"Values from column $inputColName must be indices, but got $x.") - math.max(m, x) - }, - (m0, m1) => { - math.max(m0, m1) - } - ).toInt + 1 - val outputAttrNames = Array.tabulate(numAttrs)(_.toString) - val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames - val outputAttrs: Array[Attribute] = - filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) - outputAttrGroup = new AttributeGroup(outputColName, outputAttrs) + + val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) { + OneHotEncoderCommon.getOutputAttrGroupFromData( + dataset, Seq(inputColName), Seq(outputColName), $(dropLast))(0) + } else { + outputAttrGroupFromSchema } + val metadata = outputAttrGroup.toMetadata() // data transformation diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala new file mode 100644 index 000000000000..074622d41e28 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -0,0 +1,522 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Since +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions.{col, lit, udf} +import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} + +/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ +private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid + with HasInputCols with HasOutputCols { + + /** + * Param for how to handle invalid data. + * Options are 'keep' (invalid data presented as an extra categorical feature) or + * 'error' (throw an error). 
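A short sketch of what the two `handleInvalid` options described above look like in practice (a hedged illustration, not patch content; the column names are placeholders):

```scala
// "error" (the default) throws a SparkException when transform() meets a category
// index that was not seen during fit().
// "keep" instead reserves one extra slot at the end of the output vector for such values.
val keepEncoder = new OneHotEncoderEstimator()
  .setInputCols(Array("category"))
  .setOutputCols(Array("categoryVec"))
  .setHandleInvalid("keep")
```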
+ * Default: "error" + * @group param + */ + @Since("2.3.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + "How to handle invalid data " + + "Options are 'keep' (invalid data presented as an extra categorical feature) " + + "or error (throw an error).", + ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) + + setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) + + /** + * Whether to drop the last category in the encoded vector (default: true) + * @group param + */ + @Since("2.3.0") + final val dropLast: BooleanParam = + new BooleanParam(this, "dropLast", "whether to drop the last category") + setDefault(dropLast -> true) + + /** @group getParam */ + @Since("2.3.0") + def getDropLast: Boolean = $(dropLast) + + protected def validateAndTransformSchema( + schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = { + val inputColNames = $(inputCols) + val outputColNames = $(outputCols) + val existingFields = schema.fields + + require(inputColNames.length == outputColNames.length, + s"The number of input columns ${inputColNames.length} must be the same as the number of " + + s"output columns ${outputColNames.length}.") + + // Input columns must be NumericType. + inputColNames.foreach(SchemaUtils.checkNumericType(schema, _)) + + // Prepares output columns with proper attributes by examining input columns. + val inputFields = $(inputCols).map(schema(_)) + + val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) => + OneHotEncoderCommon.transformOutputColumnSchema( + inputField, outputColName, dropLast, keepInvalid) + } + outputFields.foldLeft(schema) { case (newSchema, outputField) => + SchemaUtils.appendColumn(newSchema, outputField) + } + } +} + +/** + * A one-hot encoder that maps a column of category indices to a column of binary vectors, with + * at most a single one-value per row that indicates the input category index. + * For example with 5 categories, an input value of 2.0 would map to an output vector of + * `[0.0, 0.0, 1.0, 0.0]`. + * The last category is not included by default (configurable via `dropLast`), + * because it makes the vector entries sum up to one, and hence linearly dependent. + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + * + * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories. + * The output vectors are sparse. + * + * When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is + * added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros + * vector. + * + * @note When encoding multi-column by using `inputCols` and `outputCols` params, input/output cols + * come in pairs, specified by the order in the arrays, and each pair is treated independently. 
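As a sketch of the pairwise multi-column behaviour just described (illustrative only; the DataFrame `trainingDF` and its column names are placeholders):

```scala
// Each input/output pair is treated independently: input1 -> output1, input2 -> output2.
val multiEncoder = new OneHotEncoderEstimator()
  .setInputCols(Array("input1", "input2"))
  .setOutputCols(Array("output1", "output2"))
  .setDropLast(false)

// One category size is learned per input column during fit().
val multiModel = multiEncoder.fit(trainingDF)
```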
+ * + * @see `StringIndexer` for converting categorical values into category indices + */ +@Since("2.3.0") +class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String) + extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable { + + @Since("2.3.0") + def this() = this(Identifiable.randomUID("oneHotEncoder")) + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(values: Array[String]): this.type = set(outputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setDropLast(value: Boolean): this.type = set(dropLast, value) + + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + @Since("2.3.0") + override def transformSchema(schema: StructType): StructType = { + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID + validateAndTransformSchema(schema, dropLast = $(dropLast), + keepInvalid = keepInvalid) + } + + @Since("2.3.0") + override def fit(dataset: Dataset[_]): OneHotEncoderModel = { + transformSchema(dataset.schema) + + // Compute the plain number of categories without `handleInvalid` and + // `dropLast` taken into account. + val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false, + keepInvalid = false) + val categorySizes = new Array[Int]($(outputCols).length) + + val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) => + val numOfAttrs = AttributeGroup.fromStructField( + transformedSchema(outputColName)).size + if (numOfAttrs < 0) { + Some(idx) + } else { + categorySizes(idx) = numOfAttrs + None + } + } + + // Some input columns don't have attributes or their attributes don't have necessary info. + // We need to scan the data to get the number of values for each column. + if (columnToScanIndices.length > 0) { + val inputColNames = columnToScanIndices.map($(inputCols)(_)) + val outputColNames = columnToScanIndices.map($(outputCols)(_)) + + // When fitting data, we want the plain number of categories without `handleInvalid` and + // `dropLast` taken into account. + val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData( + dataset, inputColNames, outputColNames, dropLast = false) + attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) => + categorySizes(idx) = attrGroup.size + } + } + + val model = new OneHotEncoderModel(uid, categorySizes).setParent(this) + copyValues(model) + } + + @Since("2.3.0") + override def copy(extra: ParamMap): OneHotEncoderEstimator = defaultCopy(extra) +} + +@Since("2.3.0") +object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimator] { + + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID) + + @Since("2.3.0") + override def load(path: String): OneHotEncoderEstimator = super.load(path) +} + +@Since("2.3.0") +class OneHotEncoderModel private[ml] ( + @Since("2.3.0") override val uid: String, + @Since("2.3.0") val categorySizes: Array[Int]) + extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable { + + import OneHotEncoderModel._ + + // Returns the category size for a given index with `dropLast` and `handleInvalid` + // taken into account. 
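As an illustrative aside (not patch content), the adjustment performed below amounts to the following arithmetic:

```scala
// Mirrors the branches of configedCategorySize: add a slot when invalid values are kept,
// drop a slot when the last category is omitted.
def effectiveSize(fittedCategories: Int, dropLast: Boolean, keepInvalid: Boolean): Int =
  fittedCategories +
    (if (keepInvalid) 1 else 0) -
    (if (dropLast) 1 else 0)

// e.g. 3 fitted categories, dropLast = true, handleInvalid = "keep"  =>  3 + 1 - 1 = 3
```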
+ private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = { + val dropLast = getDropLast + val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID + + if (!dropLast && keepInvalid) { + // When `handleInvalid` is "keep", an extra category is added as last category + // for invalid data. + orgCategorySize + 1 + } else if (dropLast && !keepInvalid) { + // When `dropLast` is true, the last category is removed. + orgCategorySize - 1 + } else { + // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid + // data is removed. Thus, it is the same as the plain number of categories. + orgCategorySize + } + } + + private def encoder: UserDefinedFunction = { + val oneValue = Array(1.0) + val emptyValues = Array.empty[Double] + val emptyIndices = Array.empty[Int] + val dropLast = getDropLast + val handleInvalid = getHandleInvalid + val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID + + // The udf performed on input data. The first parameter is the input value. The second + // parameter is the index of input. + udf { (label: Double, idx: Int) => + val plainNumCategories = categorySizes(idx) + val size = configedCategorySize(plainNumCategories, idx) + + if (label < 0) { + throw new SparkException(s"Negative value: $label. Input can't be negative.") + } else if (label == size && dropLast && !keepInvalid) { + // When `dropLast` is true and `handleInvalid` is not "keep", + // the last category is removed. + Vectors.sparse(size, emptyIndices, emptyValues) + } else if (label >= plainNumCategories && keepInvalid) { + // When `handleInvalid` is "keep", encodes invalid data to last category (and removed + // if `dropLast` is true) + if (dropLast) { + Vectors.sparse(size, emptyIndices, emptyValues) + } else { + Vectors.sparse(size, Array(size - 1), oneValue) + } + } else if (label < plainNumCategories) { + Vectors.sparse(size, Array(label.toInt), oneValue) + } else { + assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + } + } + } + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(values: Array[String]): this.type = set(outputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setDropLast(value: Boolean): this.type = set(dropLast, value) + + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + @Since("2.3.0") + override def transformSchema(schema: StructType): StructType = { + val inputColNames = $(inputCols) + val outputColNames = $(outputCols) + + require(inputColNames.length == categorySizes.length, + s"The number of input columns ${inputColNames.length} must be the same as the number of " + + s"features ${categorySizes.length} during fitting.") + + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID + val transformedSchema = validateAndTransformSchema(schema, dropLast = $(dropLast), + keepInvalid = keepInvalid) + verifyNumOfValues(transformedSchema) + } + + /** + * If the metadata of input columns also specifies the number of categories, we need to + * compare with expected category number with `handleInvalid` and `dropLast` taken into + * account. Mismatched numbers will cause exception. 
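Concretely, a hedged sketch of such a mismatch, mirroring the "Transforming on mismatched attributes" test added later in this patch (`fittedModel` is assumed to have been fit on a 3-category "size" column, with spark implicits in scope):

```scala
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.functions.col

// Attach 4-value metadata to the test column, although the model saw only 3 categories.
val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large")
val testDF = Seq(0.0, 1.0, 2.0, 3.0).toDF("size")
  .select(col("size").as("size", testAttr.toMetadata()))

// Fails with "OneHotEncoderModel expected 2 categorical values ..." because the metadata
// implies 3 output attributes (dropLast = true) while the fitted size allows only 2.
fittedModel.transform(testDF)
```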
+ */ + private def verifyNumOfValues(schema: StructType): StructType = { + $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => + val inputColName = $(inputCols)(idx) + val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) + + // If the input metadata specifies number of category for output column, + // comparing with expected category number with `handleInvalid` and + // `dropLast` taken into account. + if (attrGroup.attributes.nonEmpty) { + val numCategories = configedCategorySize(categorySizes(idx), idx) + require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + + s"$numCategories categorical values for input column ${inputColName}, " + + s"but the input column had metadata specifying ${attrGroup.size} values.") + } + } + schema + } + + @Since("2.3.0") + override def transform(dataset: Dataset[_]): DataFrame = { + val transformedSchema = transformSchema(dataset.schema, logging = true) + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID + + val encodedColumns = (0 until $(inputCols).length).map { idx => + val inputColName = $(inputCols)(idx) + val outputColName = $(outputCols)(idx) + + val outputAttrGroupFromSchema = + AttributeGroup.fromStructField(transformedSchema(outputColName)) + + val metadata = if (outputAttrGroupFromSchema.size < 0) { + OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName, + categorySizes(idx), $(dropLast), keepInvalid).toMetadata() + } else { + outputAttrGroupFromSchema.toMetadata() + } + + encoder(col(inputColName).cast(DoubleType), lit(idx)) + .as(outputColName, metadata) + } + dataset.withColumns($(outputCols), encodedColumns) + } + + @Since("2.3.0") + override def copy(extra: ParamMap): OneHotEncoderModel = { + val copied = new OneHotEncoderModel(uid, categorySizes) + copyValues(copied, extra).setParent(parent) + } + + @Since("2.3.0") + override def write: MLWriter = new OneHotEncoderModelWriter(this) +} + +@Since("2.3.0") +object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { + + private[OneHotEncoderModel] + class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter { + + private case class Data(categorySizes: Array[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.categorySizes) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class OneHotEncoderModelReader extends MLReader[OneHotEncoderModel] { + + private val className = classOf[OneHotEncoderModel].getName + + override def load(path: String): OneHotEncoderModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sparkSession.read.parquet(dataPath) + .select("categorySizes") + .head() + val categorySizes = data.getAs[Seq[Int]](0).toArray + val model = new OneHotEncoderModel(metadata.uid, categorySizes) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("2.3.0") + override def read: MLReader[OneHotEncoderModel] = new OneHotEncoderModelReader + + @Since("2.3.0") + override def load(path: String): OneHotEncoderModel = super.load(path) +} + +/** + * Provides some helper methods used by both `OneHotEncoder` and `OneHotEncoderEstimator`. 
+ */ +private[feature] object OneHotEncoderCommon { + + private def genOutputAttrNames(inputCol: StructField): Option[Array[String]] = { + val inputAttr = Attribute.fromStructField(inputCol) + inputAttr match { + case nominal: NominalAttribute => + if (nominal.values.isDefined) { + nominal.values + } else if (nominal.numValues.isDefined) { + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) + } else { + None + } + case binary: BinaryAttribute => + if (binary.values.isDefined) { + binary.values + } else { + Some(Array.tabulate(2)(_.toString)) + } + case _: NumericAttribute => + throw new RuntimeException( + s"The input column ${inputCol.name} cannot be continuous-value.") + case _ => + None // optimistic about unknown attributes + } + } + + /** Creates an `AttributeGroup` filled by the `BinaryAttribute` named as required. */ + private def genOutputAttrGroup( + outputAttrNames: Option[Array[String]], + outputColName: String): AttributeGroup = { + outputAttrNames.map { attrNames => + val attrs: Array[Attribute] = attrNames.map { name => + BinaryAttribute.defaultAttr.withName(name) + } + new AttributeGroup(outputColName, attrs) + }.getOrElse{ + new AttributeGroup(outputColName) + } + } + + /** + * Prepares the `StructField` with proper metadata for `OneHotEncoder`'s output column. + */ + def transformOutputColumnSchema( + inputCol: StructField, + outputColName: String, + dropLast: Boolean, + keepInvalid: Boolean = false): StructField = { + val outputAttrNames = genOutputAttrNames(inputCol) + val filteredOutputAttrNames = outputAttrNames.map { names => + if (dropLast && !keepInvalid) { + require(names.length > 1, + s"The input column ${inputCol.name} should have at least two distinct values.") + names.dropRight(1) + } else if (!dropLast && keepInvalid) { + names ++ Seq("invalidValues") + } else { + names + } + } + + genOutputAttrGroup(filteredOutputAttrNames, outputColName).toStructField() + } + + /** + * This method is called when we want to generate `AttributeGroup` from actual data for + * one-hot encoder. + */ + def getOutputAttrGroupFromData( + dataset: Dataset[_], + inputColNames: Seq[String], + outputColNames: Seq[String], + dropLast: Boolean): Seq[AttributeGroup] = { + // The RDD approach has advantage of early-stop if any values are invalid. It seems that + // DataFrame ops don't have equivalent functions. + val columns = inputColNames.map { inputColName => + col(inputColName).cast(DoubleType) + } + val numOfColumns = columns.length + + val numAttrsArray = dataset.select(columns: _*).rdd.map { row => + (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray + }.treeAggregate(new Array[Double](numOfColumns))( + (maxValues, curValues) => { + (0 until numOfColumns).foreach { idx => + val x = curValues(idx) + assert(x <= Int.MaxValue, + s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x.") + assert(x >= 0.0 && x == x.toInt, + s"Values from column ${inputColNames(idx)} must be indices, but got $x.") + maxValues(idx) = math.max(maxValues(idx), x) + } + maxValues + }, + (m0, m1) => { + (0 until numOfColumns).foreach { idx => + m0(idx) = math.max(m0(idx), m1(idx)) + } + m0 + } + ).map(_.toInt + 1) + + outputColNames.zip(numAttrsArray).map { case (outputColName, numAttrs) => + createAttrGroupForAttrNames(outputColName, numAttrs, dropLast, keepInvalid = false) + } + } + + /** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. 
*/ + def createAttrGroupForAttrNames( + outputColName: String, + numAttrs: Int, + dropLast: Boolean, + keepInvalid: Boolean): AttributeGroup = { + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) + val filtered = if (dropLast && !keepInvalid) { + outputAttrNames.dropRight(1) + } else if (!dropLast && keepInvalid) { + outputAttrNames ++ Seq("invalidValues") + } else { + outputAttrNames + } + genOutputAttrGroup(Some(filtered), outputColName) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala new file mode 100644 index 000000000000..1d3f84558642 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + +class OneHotEncoderEstimatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + test("params") { + ParamsSuite.checkParams(new OneHotEncoderEstimator) + } + + test("OneHotEncoderEstimator dropLast = false") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + assert(encoder.getDropLast === true) + encoder.setDropLast(false) + assert(encoder.getDropLast === false) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("OneHotEncoderEstimator dropLast = true") { + val data = Seq( + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + Row(2.0, 
Vectors.sparse(2, Seq())), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(2, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("input column with ML attribute") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("size")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + val output = model.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } + + test("input column without ML attribute") { + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("index")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + val output = model.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + } + + test("read/write") { + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("index")) + .setOutputCols(Array("encoded")) + testDefaultReadWrite(encoder) + } + + test("OneHotEncoderModel read/write") { + val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3)) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.categorySizes === instance.categorySizes) + } + + test("OneHotEncoderEstimator with varying types") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val dfWithTypes = df + .withColumn("shortInput", df("input").cast(ShortType)) + .withColumn("longInput", df("input").cast(LongType)) + .withColumn("intInput", df("input").cast(IntegerType)) + .withColumn("floatInput", df("input").cast(FloatType)) + .withColumn("decimalInput", df("input").cast(DecimalType(10, 0))) + + val cols = Array("input", "shortInput", "longInput", "intInput", + "floatInput", "decimalInput") + for (col <- cols) { + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array(col)) + .setOutputCols(Array("output")) + .setDropLast(false) + + val model = encoder.fit(dfWithTypes) + val encoded = 
model.transform(dfWithTypes) + + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + } + + test("OneHotEncoderEstimator: encoding multiple columns and dropLast = false") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input1", DoubleType), + StructField("expected1", new VectorUDT), + StructField("input2", DoubleType), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("output1", "output2")) + assert(encoder.getDropLast === true) + encoder.setDropLast(false) + assert(encoder.getDropLast === false) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) + }.collect().foreach { case (vec1, vec2, vec3, vec4) => + assert(vec1 === vec2) + assert(vec3 === vec4) + } + } + + test("OneHotEncoderEstimator: encoding multiple columns and dropLast = true") { + val data = Seq( + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(1.0, Vectors.sparse(2, Seq((1, 1.0))), 3.0, Vectors.sparse(3, Seq())), + Row(2.0, Vectors.sparse(2, Seq()), 0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(2, Seq()), 2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input1", DoubleType), + StructField("expected1", new VectorUDT), + StructField("input2", DoubleType), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("output1", "output2")) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) + }.collect().foreach { case (vec1, vec2, vec3, vec4) => + assert(vec1 === vec2) + assert(vec3 === vec4) + } + } + + test("Throw error on invalid values") { + val trainingData = Seq((0, 0), (1, 1), (2, 2)) + val trainingDF = trainingData.toDF("id", "a") + val testData = Seq((0, 0), (1, 2), (1, 3)) + val testDF = testData.toDF("id", "a") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("a")) + .setOutputCols(Array("encoded")) + + val model = encoder.fit(trainingDF) + val err = intercept[SparkException] { + model.transform(testDF).show + } + err.getMessage.contains("Unseen value: 3.0. 
To handle unseen values") + } + + test("Can't transform on negative input") { + val trainingDF = Seq((0, 0), (1, 1), (2, 2)).toDF("a", "b") + val testDF = Seq((0, 0), (-1, 2), (1, 3)).toDF("a", "b") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("a")) + .setOutputCols(Array("encoded")) + + val model = encoder.fit(trainingDF) + val err = intercept[SparkException] { + model.transform(testDF).collect() + } + err.getMessage.contains("Negative value: -1.0. Input can't be negative") + } + + test("Keep on invalid values: dropLast = false") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(4, Seq((3, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + .setHandleInvalid("keep") + .setDropLast(false) + + val model = encoder.fit(trainingDF) + val encoded = model.transform(testDF) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("Keep on invalid values: dropLast = true") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(3, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + .setHandleInvalid("keep") + .setDropLast(true) + + val model = encoder.fit(trainingDF) + val encoded = model.transform(testDF) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("OneHotEncoderModel changes dropLast") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), Vectors.sparse(2, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected1", new VectorUDT), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(df) + + model.setDropLast(false) + val encoded1 = model.transform(df) + encoded1.select("output", "expected1").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + + model.setDropLast(true) + val encoded2 = model.transform(df) + encoded2.select("output", "expected2").rdd.map { r => + 
(r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("OneHotEncoderModel changes handleInvalid") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(4, Seq((3, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(trainingDF) + model.setHandleInvalid("error") + + val err = intercept[SparkException] { + model.transform(testDF).collect() + } + err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + + model.setHandleInvalid("keep") + model.transform(testDF).collect() + } + + test("Transforming on mismatched attributes") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("size")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + + val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large") + val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", testAttr.toMetadata())) + val err = intercept[Exception] { + model.transform(testDF).collect() + } + err.getMessage.contains("OneHotEncoderModel expected 2 categorical values") + } +} From f5b7714e0eed4530519df97eb4faca8b05dde161 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 1 Jan 2018 08:47:12 -0600 Subject: [PATCH 062/100] [BUILD] Close stale PRs Closes #18916 Closes #19520 Closes #19613 Closes #19739 Closes #19936 Closes #19919 Closes #19933 Closes #19917 Closes #20027 Closes #19035 Closes #20044 Closes #20104 Author: Sean Owen Closes #20130 from srowen/StalePRs. From 7a702d8d5ed830de5d2237f136b08bd18deae037 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 2 Jan 2018 07:00:31 +0900 Subject: [PATCH 063/100] [SPARK-21616][SPARKR][DOCS] update R migration guide and vignettes ## What changes were proposed in this pull request? update R migration guide and vignettes ## How was this patch tested? manually Author: Felix Cheung Closes #20106 from felixcheung/rreleasenote23. --- R/pkg/tests/fulltests/test_Windows.R | 1 + R/pkg/vignettes/sparkr-vignettes.Rmd | 3 +-- docs/sparkr.md | 6 ++++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/R/pkg/tests/fulltests/test_Windows.R b/R/pkg/tests/fulltests/test_Windows.R index b2ec6c67311d..209827d9fdc2 100644 --- a/R/pkg/tests/fulltests/test_Windows.R +++ b/R/pkg/tests/fulltests/test_Windows.R @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 8c4ea2f2db18..2e662424b25f 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -391,8 +391,7 @@ We convert `mpg` to `kmpg` (kilometers per gallon). 
`carsSubDF` is a `SparkDataF ```{r} carsSubDF <- select(carsDF, "model", "mpg") -schema <- structType(structField("model", "string"), structField("mpg", "double"), - structField("kmpg", "double")) +schema <- "model STRING, mpg DOUBLE, kmpg DOUBLE" out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) head(collect(out)) ``` diff --git a/docs/sparkr.md b/docs/sparkr.md index a3254e765413..997ea60fb6cf 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -657,3 +657,9 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - By default, derby.log is now saved to `tempdir()`. This will be created when instantiating the SparkSession with `enableHiveSupport` set to `TRUE`. - `spark.lda` was not setting the optimizer correctly. It has been corrected. - Several model summary outputs are updated to have `coefficients` as `matrix`. This includes `spark.logit`, `spark.kmeans`, `spark.glm`. Model summary outputs for `spark.gaussianMixture` have added log-likelihood as `loglik`. + +## Upgrading to SparkR 2.3.0 + + - The `stringsAsFactors` parameter was previously ignored with `collect`, for example, in `collect(createDataFrame(iris), stringsAsFactors = TRUE))`. It has been corrected. + - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. + - A warning can be raised if versions of SparkR package and the Spark JVM do not match. From c284c4e1f6f684ca8db1cc446fdcc43b46e3413c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 31 Dec 2017 17:00:41 -0600 Subject: [PATCH 064/100] [MINOR] Fix a bunch of typos --- bin/find-spark-home | 2 +- .../java/org/apache/spark/util/kvstore/LevelDBIterator.java | 2 +- .../org/apache/spark/network/protocol/MessageWithHeader.java | 4 ++-- .../java/org/apache/spark/network/sasl/SaslEncryption.java | 4 ++-- .../org/apache/spark/network/util/TransportFrameDecoder.java | 2 +- .../network/shuffle/ExternalShuffleBlockResolverSuite.java | 2 +- .../main/java/org/apache/spark/util/sketch/BloomFilter.java | 2 +- .../java/org/apache/spark/unsafe/array/ByteArrayMethods.java | 2 +- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- core/src/main/scala/org/apache/spark/status/storeTypes.scala | 2 +- .../test/scala/org/apache/spark/util/FileAppenderSuite.scala | 2 +- dev/github_jira_sync.py | 2 +- dev/lint-python | 2 +- examples/src/main/python/ml/linearsvc.py | 2 +- .../scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala | 2 +- .../scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala | 2 +- .../apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java | 2 +- .../org/apache/spark/streaming/kinesis/KinesisUtils.scala | 4 ++-- .../java/org/apache/spark/launcher/ChildProcAppHandle.java | 2 +- .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 2 +- python/pyspark/ml/image.py | 2 +- .../org/apache/spark/deploy/yarn/YarnClusterSuite.scala | 2 +- .../spark/sql/catalyst/expressions/UnsafeArrayData.java | 2 +- .../scala/org/apache/spark/sql/catalyst/analysis/view.scala | 2 +- .../spark/sql/catalyst/expressions/objects/objects.scala | 5 +++-- .../expressions/aggregate/CountMinSketchAggSuite.scala | 2 +- .../sql/sources/v2/streaming/MicroBatchWriteSupport.java | 2 +- .../apache/spark/sql/execution/ui/static/spark-sql-viz.css | 2 +- .../apache/spark/sql/execution/datasources/FileFormat.scala | 2 +- .../spark/sql/execution/datasources/csv/CSVInferSchema.scala | 2 +- .../apache/spark/sql/execution/joins/HashedRelation.scala | 2 +- 
.../streaming/StreamingSymmetricHashJoinHelper.scala | 2 +- .../apache/spark/sql/execution/ui/SQLAppStatusListener.scala | 2 +- .../scala/org/apache/spark/sql/expressions/Aggregator.scala | 2 +- .../main/scala/org/apache/spark/sql/streaming/progress.scala | 2 +- .../src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java | 2 +- .../inputs/typeCoercion/native/implicitTypeCasts.sql | 2 +- .../execution/streaming/CompactibleFileStreamLogSuite.scala | 2 +- .../org/apache/spark/sql/sources/fakeExternalSources.scala | 2 +- .../org/apache/spark/sql/streaming/FileStreamSinkSuite.scala | 2 +- .../test/scala/org/apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../src/main/scala/org/apache/spark/sql/hive/HiveShim.scala | 2 +- sql/hive/src/test/resources/data/conf/hive-log4j.properties | 2 +- .../org/apache/spark/streaming/rdd/MapWithStateRDD.scala | 2 +- .../scala/org/apache/spark/streaming/util/StateMap.scala | 2 +- 45 files changed, 50 insertions(+), 49 deletions(-) diff --git a/bin/find-spark-home b/bin/find-spark-home index fa78407d4175..617dbaa4fff8 100755 --- a/bin/find-spark-home +++ b/bin/find-spark-home @@ -21,7 +21,7 @@ FIND_SPARK_HOME_PYTHON_SCRIPT="$(cd "$(dirname "$0")"; pwd)/find_spark_home.py" -# Short cirtuit if the user already has this set. +# Short circuit if the user already has this set. if [ ! -z "${SPARK_HOME}" ]; then exit 0 elif [ ! -f "$FIND_SPARK_HOME_PYTHON_SCRIPT" ]; then diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index b3ba76ba5805..f62e85d43531 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -86,7 +86,7 @@ class LevelDBIterator implements KVStoreIterator { end = index.start(parent, params.last); } if (it.hasNext()) { - // When descending, the caller may have set up the start of iteration at a non-existant + // When descending, the caller may have set up the start of iteration at a non-existent // entry that is guaranteed to be after the desired entry. For example, if you have a // compound key (a, b) where b is a, integer, you may seek to the end of the elements that // have the same "a" value by specifying Integer.MAX_VALUE for "b", and that value may not diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index 897d0f9e4fb8..a5337656cbd8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -47,7 +47,7 @@ class MessageWithHeader extends AbstractFileRegion { /** * When the write buffer size is larger than this limit, I/O will be done in chunks of this size. * The size should not be too large as it will waste underlying memory copy. e.g. If network - * avaliable buffer is smaller than this limit, the data cannot be sent within one single write + * available buffer is smaller than this limit, the data cannot be sent within one single write * operation while it still will make memory copy with this size. */ private static final int NIO_BUFFER_LIMIT = 256 * 1024; @@ -100,7 +100,7 @@ public long transferred() { * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. 
* * The contract is that the caller will ensure position is properly set to the total number - * of bytes transferred so far (i.e. value returned by transfered()). + * of bytes transferred so far (i.e. value returned by transferred()). */ @Override public long transferTo(final WritableByteChannel target, final long position) throws IOException { diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 16ab4efcd4f5..3ac9081d78a7 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -38,7 +38,7 @@ import org.apache.spark.network.util.NettyUtils; /** - * Provides SASL-based encription for transport channels. The single method exposed by this + * Provides SASL-based encryption for transport channels. The single method exposed by this * class installs the needed channel handlers on a connected channel. */ class SaslEncryption { @@ -166,7 +166,7 @@ static class EncryptedMessage extends AbstractFileRegion { * This makes assumptions about how netty treats FileRegion instances, because there's no way * to know beforehand what will be the size of the encrypted message. Namely, it assumes * that netty will try to transfer data from this message while - * transfered() < count(). So these two methods return, technically, wrong data, + * transferred() < count(). So these two methods return, technically, wrong data, * but netty doesn't know better. */ @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 50d9651ccbbb..8e73ab077a5c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -29,7 +29,7 @@ /** * A customized frame decoder that allows intercepting raw data. *

- * This behaves like Netty's frame decoder (with harcoded parameters that match this library's + * This behaves like Netty's frame decoder (with hard coded parameters that match this library's * needs), except it allows an interceptor to be installed to read data directly before it's * framed. *

diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 23438a08fa09..6d201b8fe8d7 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -127,7 +127,7 @@ public void jsonSerializationOfExecutorRegistration() throws IOException { mapper.readValue(shuffleJson, ExecutorShuffleInfo.class); assertEquals(parsedShuffleInfo, shuffleInfo); - // Intentionally keep these hard-coded strings in here, to check backwards-compatability. + // Intentionally keep these hard-coded strings in here, to check backwards-compatibility. // its not legacy yet, but keeping this here in case anybody changes it String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}"; assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class)); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index c0b425e72959..37803c7a3b10 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -34,7 +34,7 @@ *

  • {@link String}
  • * * The false positive probability ({@code FPP}) of a Bloom filter is defined as the probability that - * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that hasu + * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that has * not actually been put in the {@code BloomFilter}. * * The implementation is largely based on the {@code BloomFilter} class from Guava. diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index f121b1cd745b..a6b1f7a16d60 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -66,7 +66,7 @@ public static boolean arrayEquals( i += 1; } } - // for architectures that suport unaligned accesses, chew it up 8 bytes at a time + // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (unaligned || (((leftOffset + i) % 8 == 0) && ((rightOffset + i) % 8 == 0))) { while (i <= length - 8) { if (Platform.getLong(leftBase, leftOffset + i) != diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1dbb3789b0c9..31f3cb9dfa0a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2276,7 +2276,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Clean a closure to make it ready to serialized and send to tasks + * Clean a closure to make it ready to be serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) * If checkSerializable is set, clean will also proactively * check to see if f is serializable and throw a SparkException diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index d9ead0071d3b..1cfd30df4909 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -61,7 +61,7 @@ private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { /** * Keep track of the existing stages when the job was submitted, and those that were - * completed during the job's execution. This allows a more accurate acounting of how + * completed during the job's execution. This allows a more accurate accounting of how * many tasks were skipped for the job. 
*/ private[spark] class JobDataWrapper( diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index cd0ed5b036bf..52cd5378bc71 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -356,7 +356,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { generatedFiles } - /** Delete all the generated rolledover files */ + /** Delete all the generated rolled over files */ def cleanup() { testFile.getParentFile.listFiles.filter { file => file.getName.startsWith(testFile.getName) diff --git a/dev/github_jira_sync.py b/dev/github_jira_sync.py index acc9aeabbb9f..7df7b325cabe 100755 --- a/dev/github_jira_sync.py +++ b/dev/github_jira_sync.py @@ -43,7 +43,7 @@ # "notification overload" when running for the first time. MIN_COMMENT_PR = int(os.environ.get("MIN_COMMENT_PR", "1496")) -# File used as an opitimization to store maximum previously seen PR +# File used as an optimization to store maximum previously seen PR # Used mostly because accessing ASF JIRA is slow, so we want to avoid checking # the state of JIRA's that are tied to PR's we've already looked at. MAX_FILE = ".github-jira-max" diff --git a/dev/lint-python b/dev/lint-python index 07e2606d4514..df8df037a5f6 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,7 +19,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -# Exclude auto-geneated configuration file. +# Exclude auto-generated configuration file. PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" )" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" diff --git a/examples/src/main/python/ml/linearsvc.py b/examples/src/main/python/ml/linearsvc.py index 18cbf87a1069..9b79abbf96f8 100644 --- a/examples/src/main/python/ml/linearsvc.py +++ b/examples/src/main/python/ml/linearsvc.py @@ -37,7 +37,7 @@ # Fit the model lsvcModel = lsvc.fit(training) - # Print the coefficients and intercept for linearsSVC + # Print the coefficients and intercept for linear SVC print("Coefficients: " + str(lsvcModel.coefficients)) print("Intercept: " + str(lsvcModel.intercept)) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 9d9e2aaba807..66b3409c0cd0 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -52,7 +52,7 @@ private[kafka010] case class KafkaSourceRDDPartition( * An RDD that reads data from Kafka based on offset ranges across multiple partitions. * Additionally, it allows preferred locations to be set for each topic + partition, so that * the [[KafkaSource]] can ensure the same executor always reads the same topic + partition - * and cached KafkaConsuemrs (see [[CachedKafkaConsumer]] can be used read data efficiently. + * and cached KafkaConsumers (see [[CachedKafkaConsumer]] can be used read data efficiently. 
* * @param sc the [[SparkContext]] * @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 08db4d827e40..75245943c493 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -201,7 +201,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L verifyTopicDeletionWithRetries(zkUtils, topic, partitions, List(this.server)) } - /** Add new paritions to a Kafka topic */ + /** Add new partitions to a Kafka topic */ def addPartitions(topic: String, partitions: Int): Unit = { AdminUtils.addPartitions(zkUtils, topic, partitions) // wait until metadata is propagated diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java index 87bfe1514e33..b20fad229126 100644 --- a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java +++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java @@ -111,7 +111,7 @@ public String call(ConsumerRecord r) { LocationStrategies.PreferConsistent() ).map(handler); - // just making sure the java user apis work; the scala tests handle logic corner cases + // just making sure the java user APIs work; the scala tests handle logic corner cases long count1 = rdd1.count(); long count2 = rdd2.count(); Assert.assertTrue(count1 > 0); diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2500460bd330..c60b9896a347 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -164,7 +164,7 @@ object KinesisUtils { * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from * Kinesis stream. - * @param stsSessionName Name to uniquely identify STS sessions if multiple princples assume + * @param stsSessionName Name to uniquely identify STS sessions if multiple principals assume * the same role. * @param stsExternalId External ID that can be used to validate against the assumed IAM role's * trust policy. @@ -434,7 +434,7 @@ object KinesisUtils { * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from * Kinesis stream. - * @param stsSessionName Name to uniquely identify STS sessions if multiple princples assume + * @param stsSessionName Name to uniquely identify STS sessions if multiple princpals assume * the same role. * @param stsExternalId External ID that can be used to validate against the assumed IAM role's * trust policy. 
diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 3bb7e12385fd..8b3f427b7750 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -71,7 +71,7 @@ void setChildProc(Process childProc, String loggerName, InputStream logStream) { } /** - * Wait for the child process to exit and update the handle's state if necessary, accoding to + * Wait for the child process to exit and update the handle's state if necessary, according to * the exit code. */ void monitorChild() { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 0130b3e255f0..095b54c0fe83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -94,7 +94,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) def setSeed(value: Long): this.type = set(seed, value) /** - * Set the mamixum level of parallelism to evaluate models in parallel. + * Set the maximum level of parallelism to evaluate models in parallel. * Default is 1 for serial evaluation * * @group expertSetParam diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 384599dc0c53..c9b840276f67 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -212,7 +212,7 @@ def readImages(self, path, recursive=False, numPartitions=-1, ImageSchema = _ImageSchema() -# Monkey patch to disallow instantization of this class. +# Monkey patch to disallow instantiation of this class. 
def _disallow_instance(_): raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") _ImageSchema.__init__ = _disallow_instance diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index ab0005d7b53a..061f653b97b7 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -95,7 +95,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", "spark.executor.instances" -> "2", - // Sending some senstive information, which we'll make sure gets redacted + // Sending some sensitive information, which we'll make sure gets redacted "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 64ab01ca5740..d18542b188f7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -294,7 +294,7 @@ public void setNullAt(int ordinal) { assertIndexIsValid(ordinal); BitSetMethods.set(baseObject, baseOffset + 8, ordinal); - /* we assume the corrresponding column was already 0 or + /* we assume the corresponding column was already 0 or will be set to 0 later by the caller side */ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 3bbe41cf8f15..20216087b015 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.internal.SQLConf * view resolution, in this way, we are able to get the correct view column ordering and * omit the extra columns that we don't require); * 1.2. Else set the child output attributes to `queryOutput`. - * 2. Map the `queryQutput` to view output by index, if the corresponding attributes don't match, + * 2. Map the `queryOutput` to view output by index, if the corresponding attributes don't match, * try to up cast and alias the attribute in `queryOutput` to the attribute in the view output. * 3. Add a Project over the child, with the new output generated by the previous steps. * If the view output doesn't have the same number of columns neither with the child output, nor diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4af813456b79..64da9bb9cdec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -51,7 +51,7 @@ trait InvokeLike extends Expression with NonSQLExpression { * * - generate codes for argument. 
* - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments. - * - avoid some of nullabilty checking which are not needed because the expression is not + * - avoid some of nullability checking which are not needed because the expression is not * nullable. * - when needNullCheck == true, short circuit if we found one of arguments is null because * preparing rest of arguments can be skipped in the case. @@ -193,7 +193,8 @@ case class StaticInvoke( * @param targetObject An expression that will return the object to call the method on. * @param functionName The name of the method to call. * @param dataType The expected return type of the function. - * @param arguments An optional list of expressions, whos evaluation will be passed to the function. + * @param arguments An optional list of expressions, whose evaluation will be passed to the + * function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead * of calling the function. * @param returnNullable When false, indicating the invoked method will always return diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala index 10479630f3f9..30e3bc9fb577 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.CountMinSketch /** - * Unit test suite for the count-min sketch SQL aggregate funciton [[CountMinSketchAgg]]. + * Unit test suite for the count-min sketch SQL aggregate function [[CountMinSketchAgg]]. */ class CountMinSketchAggSuite extends SparkFunSuite { private val childExpression = BoundReference(0, IntegerType, nullable = true) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java index b5e3e44cd633..53ffa95ae0f4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java @@ -41,7 +41,7 @@ public interface MicroBatchWriteSupport extends BaseStreamingSink { * @param queryId A unique string for the writing query. It's possible that there are many writing * queries running at the same time, and the returned {@link DataSourceV2Writer} * can use this id to distinguish itself from others. - * @param epochId The uniquenumeric ID of the batch within this writing query. This is an + * @param epochId The unique numeric ID of the batch within this writing query. This is an * incrementing counter representing a consistent set of data; the same batch may * be started multiple times in failure recovery scenarios, but it will always * contain the same records. 
diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css index 594e747a8d3a..b13850c30149 100644 --- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css +++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css @@ -32,7 +32,7 @@ stroke-width: 1px; } -/* Hightlight the SparkPlan node name */ +/* Highlight the SparkPlan node name */ #plan-viz-graph svg text :first-child { font-weight: bold; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index d3874b58bc80..023e12788829 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -77,7 +77,7 @@ trait FileFormat { } /** - * Returns whether a file with `path` could be splitted or not. + * Returns whether a file with `path` could be split or not. */ def isSplitable( sparkSession: SparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index b64d71bb4eef..a585cbed2551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -150,7 +150,7 @@ private[csv] object CSVInferSchema { if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { TimestampType } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { - // We keep this for backwords competibility. + // We keep this for backwards compatibility. TimestampType } else { tryParseBoolean(field, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index d98cf852a1b4..1465346eb802 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -368,7 +368,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap // The minimum key private var minKey = Long.MaxValue - // The maxinum key + // The maximum key private var maxKey = Long.MinValue // The array to store the key and offset of UnsafeRow in the page. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 217e98a6419e..4aba76cad367 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -203,7 +203,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { /** * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks' * preferred location is based on which executors have the required join state stores already - * loaded. This is class is a modified verion of [[ZippedPartitionsRDD2]]. + * loaded. 
This is class is a modified version of [[ZippedPartitionsRDD2]]. */ class StateStoreAwareZipPartitionsRDD[A: ClassTag, B: ClassTag, V: ClassTag]( sc: SparkContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2295b8dd5fe3..d8adbe7bee13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -175,7 +175,7 @@ class SQLAppStatusListener( // Check the execution again for whether the aggregated metrics data has been calculated. // This can happen if the UI is requesting this data, and the onExecutionEnd handler is - // running at the same time. The metrics calculcated for the UI can be innacurate in that + // running at the same time. The metrics calculated for the UI can be innacurate in that // case, since the onExecutionEnd handler will clean up tracked stage metrics. if (exec.metricsValues != null) { exec.metricsValues diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 058c38c8cb8f..1e076207bc60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -86,7 +86,7 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable { def bufferEncoder: Encoder[BUF] /** - * Specifies the `Encoder` for the final ouput value type. + * Specifies the `Encoder` for the final output value type. * @since 2.0.0 */ def outputEncoder: Encoder[OUT] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index cedc1dce4a70..0dcb666e2c3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -152,7 +152,7 @@ class StreamingQueryProgress private[sql]( * @param endOffset The ending offset for data being read. * @param numInputRows The number of records read from this source. * @param inputRowsPerSecond The rate at which data is arriving from this source. - * @param processedRowsPerSecond The rate at which data from this source is being procressed by + * @param processedRowsPerSecond The rate at which data from this source is being processed by * Spark. * @since 2.1.0 */ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java index 447a71d284fb..288f5e7426c0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java @@ -47,7 +47,7 @@ public MyDoubleAvg() { _inputDataType = DataTypes.createStructType(inputFields); // The buffer has two values, bufferSum for storing the current sum and - // bufferCount for storing the number of non-null input values that have been contribuetd + // bufferCount for storing the number of non-null input values that have been contributed // to the current sum. 
List bufferFields = new ArrayList<>(); bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql index 58866f4b1811..6de22b8b7c3d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql @@ -32,7 +32,7 @@ SELECT 1.1 - '2.2' FROM t; SELECT 1.1 * '2.2' FROM t; SELECT 4.4 / '2.2' FROM t; --- concatentation +-- concatenation SELECT '$' || cast(1 as smallint) || '$' FROM t; SELECT '$' || 1 || '$' FROM t; SELECT '$' || cast(1 as bigint) || '$' FROM t; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 83018f95aa55..12eaf6341508 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -92,7 +92,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext test("deriveCompactInterval") { // latestCompactBatchId(4) + 1 <= default(5) - // then use latestestCompactBatchId + 1 === 5 + // then use latestCompactBatchId + 1 === 5 assert(5 === deriveCompactInterval(5, 4)) // First divisor of 10 greater than 4 === 5 assert(5 === deriveCompactInterval(4, 9)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala index bf43de597a7a..2cb48281b30a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationP import org.apache.spark.sql.types._ -// Note that the package name is intendedly mismatched in order to resemble external data sources +// Note that the package name is intentionally mismatched in order to resemble external data sources // and test the detection for them. class FakeExternalSourceOne extends RelationProvider with DataSourceRegister { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 2a2552211857..8c4e1fd00b0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -81,7 +81,7 @@ class FileStreamSinkSuite extends StreamTest { .start(outputDir) try { - // The output is partitoned by "value", so the value will appear in the file path. + // The output is partitioned by "value", so the value will appear in the file path. // This is to test if we handle spaces in the path correctly. 
inputData.addData("hello world") failAfter(streamingTimeout) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index b4248b74f50a..904f9f2ad0b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -113,7 +113,7 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with if (thread.isAlive) { thread.interrupt() // If this interrupt does not work, then this thread is most likely running something that - // is not interruptible. There is not much point to wait for the thread to termniate, and + // is not interruptible. There is not much point to wait for the thread to terminate, and // we rather let the JVM terminate the thread on exit. fail( s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index 9e9894803ce2..11afe1af3280 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -50,7 +50,7 @@ private[hive] object HiveShim { val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro" /* - * This function in hive-0.13 become private, but we have to do this to walkaround hive bug + * This function in hive-0.13 become private, but we have to do this to work around hive bug */ private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") diff --git a/sql/hive/src/test/resources/data/conf/hive-log4j.properties b/sql/hive/src/test/resources/data/conf/hive-log4j.properties index 6a042472adb9..83fd03a99bc3 100644 --- a/sql/hive/src/test/resources/data/conf/hive-log4j.properties +++ b/sql/hive/src/test/resources/data/conf/hive-log4j.properties @@ -32,7 +32,7 @@ log4j.threshhold=WARN log4j.appender.DRFA=org.apache.log4j.DailyRollingFileAppender log4j.appender.DRFA.File=${hive.log.dir}/${hive.log.file} -# Rollver at midnight +# Roll over at midnight log4j.appender.DRFA.DatePattern=.yyyy-MM-dd # 30-day backup diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index 15d3c7e54b8d..8da5a5f8193c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -162,7 +162,7 @@ private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, mappingFunction, batchTime, timeoutThresholdTime, - removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled + removeTimedoutData = doFullScan // remove timed-out data only when full scan is enabled ) Iterator(newRecord) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 3a21cfae5ac2..89524cd84ff3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -364,7 +364,7 @@ private[streaming] object OpenHashMapBasedStateMap { } /** - * Internal class to represent a 
marker the demarkate the end of all state data in the + * Internal class to represent a marker that demarcates the end of all state data in the * serialized bytes. */ class LimitMarker(val num: Int) extends Serializable From 1c9f95cb771ac78775a77edd1abfeb2d8ae2a124 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 2 Jan 2018 07:13:27 +0900 Subject: [PATCH 065/100] [SPARK-22530][PYTHON][SQL] Adding Arrow support for ArrayType ## What changes were proposed in this pull request? This change adds `ArrayType` support for working with Arrow in pyspark when creating a DataFrame, calling `toPandas()`, and using vectorized `pandas_udf`. ## How was this patch tested? Added new Python unit tests using Array data. Author: Bryan Cutler Closes #20114 from BryanCutler/arrow-ArrayType-support-SPARK-22530. --- python/pyspark/sql/tests.py | 47 ++++++++++++++++++- python/pyspark/sql/types.py | 4 ++ .../vectorized/ArrowColumnVector.java | 13 ++++- 3 files changed, 61 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1c34c897eecb..67bdb3d72d93 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3372,6 +3372,31 @@ def test_schema_conversion_roundtrip(self): schema_rt = from_arrow_schema(arrow_schema) self.assertEquals(self.schema, schema_rt) + def test_createDataFrame_with_array_type(self): + import pandas as pd + pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]}) + df, df_arrow = self._createDataFrame_toggle(pdf) + result = df.collect() + result_arrow = df_arrow.collect() + expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + + def test_toPandas_with_array_type(self): + expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])] + array_schema = StructType([StructField("a", ArrayType(IntegerType())), + StructField("b", ArrayType(StringType()))]) + df = self.spark.createDataFrame(expected, schema=array_schema) + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class PandasUDFTests(ReusedSQLTestCase): @@ -3651,6 +3676,24 @@ def test_vectorized_udf_datatype_string(self): bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_array_type(self): + from pyspark.sql.functions import pandas_udf, col + data = [([1, 2],), ([3, 4],)] + array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) + df = self.spark.createDataFrame(data, schema=array_schema) + array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) + result = df.select(array_f(col('array'))) + self.assertEquals(df.collect(), result.collect()) + + def test_vectorized_udf_null_array(self): + from pyspark.sql.functions import pandas_udf, col + data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)] + array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) + df = self.spark.createDataFrame(data, schema=array_schema) + array_f = pandas_udf(lambda x: x, 
ArrayType(IntegerType())) + result = df.select(array_f(col('array'))) + self.assertEquals(df.collect(), result.collect()) + def test_vectorized_udf_complex(self): from pyspark.sql.functions import pandas_udf, col, expr df = self.spark.range(10).select( @@ -3705,7 +3748,7 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) - f = pandas_udf(lambda x: x * 1.0, ArrayType(LongType())) + f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): df.select(f(col('id'))).collect() @@ -4009,7 +4052,7 @@ def test_wrong_return_type(self): foo = pandas_udf( lambda pdf: pdf, - 'id long, v array', + 'id long, v map', PandasUDFType.GROUP_MAP ) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 02b245713b6a..146e673ae975 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1625,6 +1625,8 @@ def to_arrow_type(dt): elif type(dt) == TimestampType: # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') + elif type(dt) == ArrayType: + arrow_type = pa.list_(to_arrow_type(dt.elementType)) else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type @@ -1665,6 +1667,8 @@ def from_arrow_type(at): spark_type = DateType() elif types.is_timestamp(at): spark_type = TimestampType() + elif types.is_list(at): + spark_type = ArrayType(from_arrow_type(at.value_type)) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) return spark_type diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 528f66f342dc..af5673e26a50 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -326,7 +326,8 @@ private abstract static class ArrowVectorAccessor { this.vector = vector; } - final boolean isNullAt(int rowId) { + // TODO: should be final after removing ArrayAccessor workaround + boolean isNullAt(int rowId) { return vector.isNull(rowId); } @@ -589,6 +590,16 @@ private static class ArrayAccessor extends ArrowVectorAccessor { this.accessor = vector; } + @Override + final boolean isNullAt(int rowId) { + // TODO: Workaround if vector has all non-null values, see ARROW-1948 + if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) { + return false; + } else { + return super.isNullAt(rowId); + } + } + @Override final int getArrayLength(int rowId) { return accessor.getInnerValueCountAt(rowId); From e734a4b9c23463a7fea61011027a822bc9e11c98 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 2 Jan 2018 07:20:05 +0900 Subject: [PATCH 066/100] [SPARK-21893][SPARK-22142][TESTS][FOLLOWUP] Enables PySpark tests for Flume and Kafka in Jenkins ## What changes were proposed in this pull request? This PR proposes to enable PySpark tests for Flume and Kafka in Jenkins by explicitly setting the environment variables in `modules.py`. 
Seems we are not taking the dependencies into account when calculating environment variables: https://github.com/apache/spark/blob/3a07eff5af601511e97a05e6fea0e3d48f74c4f0/dev/run-tests.py#L554-L561 ## How was this patch tested? Manual tests with Jenkins in https://github.com/apache/spark/pull/20126. **Before** - https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/85559/consoleFull ``` [info] Setup the following environment variables for tests: ... ``` **After** - https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/85560/consoleFull ``` [info] Setup the following environment variables for tests: ENABLE_KAFKA_0_8_TESTS=1 ENABLE_FLUME_TESTS=1 ... ``` Author: hyukjinkwon Closes #20128 from HyukjinKwon/SPARK-21893. --- dev/sparktestsupport/modules.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 44f990ec8a5a..f834563da9dd 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -418,6 +418,10 @@ def __hash__(self): source_file_regexes=[ "python/pyspark/streaming" ], + environ={ + "ENABLE_FLUME_TESTS": "1", + "ENABLE_KAFKA_0_8_TESTS": "1" + }, python_test_goals=[ "pyspark.streaming.util", "pyspark.streaming.tests", From e0c090f227e9b64e595b47d4d1f96f8a2fff5bf7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 2 Jan 2018 09:19:18 +0800 Subject: [PATCH 067/100] [SPARK-22932][SQL] Refactor AnalysisContext ## What changes were proposed in this pull request? Add a `reset` function to ensure the state in `AnalysisContext ` is per-query. ## How was this patch tested? The existing test cases Author: gatorsmile Closes #20127 from gatorsmile/refactorAnalysisContext. --- .../sql/catalyst/analysis/Analyzer.scala | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6d294d48c0ee..35b35110e491 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -52,6 +52,7 @@ object SimpleAnalyzer extends Analyzer( /** * Provides a way to keep state during the analysis, this enables us to decouple the concerns * of analysis environment from the catalog. + * The state that is kept here is per-query. * * Note this is thread local. 
* @@ -70,6 +71,8 @@ object AnalysisContext { } def get: AnalysisContext = value.get() + def reset(): Unit = value.remove() + private def set(context: AnalysisContext): Unit = value.set(context) def withAnalysisContext[A](database: Option[String])(f: => A): A = { @@ -95,6 +98,17 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + override def execute(plan: LogicalPlan): LogicalPlan = { + AnalysisContext.reset() + try { + executeSameContext(plan) + } finally { + AnalysisContext.reset() + } + } + + private def executeSameContext(plan: LogicalPlan): LogicalPlan = super.execute(plan) + def resolver: Resolver = conf.resolver protected val fixedPoint = FixedPoint(maxIterations) @@ -176,7 +190,7 @@ class Analyzer( case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => - resolved :+ name -> execute(substituteCTE(relation, resolved)) + resolved :+ name -> executeSameContext(substituteCTE(relation, resolved)) }) case other => other } @@ -600,7 +614,7 @@ class Analyzer( "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + "aroud this.") } - execute(child) + executeSameContext(child) } view.copy(child = newChild) case p @ SubqueryAlias(_, view: View) => @@ -1269,7 +1283,7 @@ class Analyzer( do { // Try to resolve the subquery plan using the regular analyzer. previous = current - current = execute(current) + current = executeSameContext(current) // Use the outer references to resolve the subquery plan if it isn't resolved yet. val i = plans.iterator @@ -1392,7 +1406,7 @@ class Analyzer( grouping, Alias(cond, "havingCondition")() :: Nil, child) - val resolvedOperator = execute(aggregatedCondition) + val resolvedOperator = executeSameContext(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator .asInstanceOf[Aggregate] @@ -1450,7 +1464,8 @@ class Analyzer( val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) - val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAggregate: Aggregate = + executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] From a6fc300e91273230e7134ac6db95ccb4436c6f8f Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Tue, 2 Jan 2018 23:30:38 +0800 Subject: [PATCH 068/100] [SPARK-22897][CORE] Expose stageAttemptId in TaskContext ## What changes were proposed in this pull request? stageAttemptId added in TaskContext and corresponding construction modification ## How was this patch tested? Added a new test in TaskContextSuite, two cases are tested: 1. Normal case without failure 2. Exception case with resubmitted stages Link to [SPARK-22897](https://issues.apache.org/jira/browse/SPARK-22897) Author: Xianjin YE Closes #20082 from advancedxy/SPARK-22897. 
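For illustration only (an editor's sketch, not part of this patch): once this change is in, user code can read the new field through `TaskContext`. The example assumes an existing `SparkContext` named `sc`; the variable names are illustrative.

```
// Sketch: read the stage attempt number inside a task (assumes SparkContext `sc`).
import org.apache.spark.TaskContext

val attempts = sc.parallelize(1 to 4, 2).mapPartitions { iter =>
  val tc = TaskContext.get()
  // 0 on the first attempt of the stage; larger if the stage was resubmitted,
  // e.g. after a FetchFailedException.
  Iterator((tc.stageId(), tc.stageAttemptNumber(), iter.size))
}.collect()
```

The new `stageAttemptNumber()` is stage-level and complements the existing task-level `attemptNumber()`.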
--- .../scala/org/apache/spark/TaskContext.scala | 9 +++++- .../org/apache/spark/TaskContextImpl.scala | 5 ++-- .../org/apache/spark/scheduler/Task.scala | 1 + .../spark/JavaTaskContextCompileCheck.java | 2 ++ .../scala/org/apache/spark/ShuffleSuite.scala | 6 ++-- .../spark/memory/MemoryTestingUtils.scala | 1 + .../spark/scheduler/TaskContextSuite.scala | 29 +++++++++++++++++-- .../spark/storage/BlockInfoManagerSuite.scala | 2 +- project/MimaExcludes.scala | 3 ++ .../UnsafeFixedWidthAggregationMapSuite.scala | 1 + .../UnsafeKVExternalSorterSuite.scala | 1 + .../execution/UnsafeRowSerializerSuite.scala | 2 +- .../SortBasedAggregationStoreSuite.scala | 3 +- 13 files changed, 54 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0b87cd503d4f..69739745aa6c 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -66,7 +66,7 @@ object TaskContext { * An empty task context that does not represent an actual task. This is only used in tests. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) + new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null) } } @@ -150,6 +150,13 @@ abstract class TaskContext extends Serializable { */ def stageId(): Int + /** + * How many times the stage that this task belongs to has been attempted. The first stage attempt + * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt + * numbers. + */ + def stageAttemptNumber(): Int + /** * The ID of the RDD partition that is computed by this task. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 01d8973e1bb0..cccd3ea457ba 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -41,8 +41,9 @@ import org.apache.spark.util._ * `TaskMetrics` & `MetricsSystem` objects are not thread safe. 
*/ private[spark] class TaskContextImpl( - val stageId: Int, - val partitionId: Int, + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7767ef1803a0..f536fc2a5f0a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -79,6 +79,7 @@ private[spark] abstract class Task[T]( SparkEnv.get.blockManager.registerTask(taskAttemptId) context = new TaskContextImpl( stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal partitionId, taskAttemptId, attemptNumber, diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 94f5805853e1..f8e233a05a44 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -38,6 +38,7 @@ public static void test() { tc.attemptNumber(); tc.partitionId(); tc.stageId(); + tc.stageAttemptNumber(); tc.taskAttemptId(); } @@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); context.stageId(); + context.stageAttemptNumber(); context.partitionId(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 3931d53b4ae0..ced5a06516f7 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. 
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 362cd861cc24..dcf89e4f75ac 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -29,6 +29,7 @@ object MemoryTestingUtils { val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 0, attemptNumber = 0, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index a1d9085fa085..aa9c36c0aaac 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -158,6 +159,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } + test("TaskContext.stageAttemptNumber getter") { + sc = new SparkContext("local[1,2]", "test") + + // Check stageAttemptNumbers are 0 for initial stage + val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ => + Seq(TaskContext.get().stageAttemptNumber()).iterator + }.collect() + assert(stageAttemptNumbers.toSet === Set(0)) + + // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException + val stageAttemptNumbersWithFailedStage = + sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ => + val stageAttemptNumber = TaskContext.get().stageAttemptNumber() + if (stageAttemptNumber < 2) { + // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception + // will only trigger task resubmission in the same stage. + throw new FetchFailedException(null, 0, 0, 0, "Fake") + } + Seq(stageAttemptNumber).iterator + }.collect() + + assert(stageAttemptNumbersWithFailedStage.toSet === Set(2)) + } + test("accumulators are updated on exception failures") { // This means use 1 core and 4 max task failures sc = new SparkContext("local[1,4]", "test") @@ -190,7 +215,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. 
val taskMetrics = TaskMetrics.empty val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, @@ -213,7 +238,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 917db766f7f1..9c0699bc981f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { try { TaskContext.setTaskContext( - new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null)) + new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null)) block } finally { TaskContext.unset() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 81584af6813e..3b452f35c5ec 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-22897] Expose stageAttemptId in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"), + // SPARK-22789: Map-only continuous processing execution ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 232c1beae799..3e31d22e15c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = Random.nextInt(10000), attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 604502f2a57d..6af9f8b77f8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId 
= 98456, attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index dff88ce7f1b9..a3ae93810aa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { (i, converter(Row(i))) } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 10f1ee279bed..3fad7dfddadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte val conf = new SparkConf() sc = new SparkContext("local[2, 4]", "test", conf) val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0) - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null)) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) } override def afterAll(): Unit = TaskContext.unset() From 247a08939d58405aef39b2a4e7773aa45474ad12 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 3 Jan 2018 21:40:51 +0800 Subject: [PATCH 069/100] [SPARK-22938] Assert that SQLConf.get is accessed only on the driver. ## What changes were proposed in this pull request? Assert if code tries to access SQLConf.get on executor. This can lead to hard to detect bugs, where the executor will read fallbackConf, falling back to default config values, ignoring potentially changed non-default configs. If a config is to be passed to executor code, it needs to be read on the driver, and passed explicitly. ## How was this patch tested? Check in existing tests. Author: Juliusz Sompolski Closes #20136 from juliuszsompolski/SPARK-22938. 
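To illustrate the pattern this assertion encourages (an editor's sketch, not part of this patch; it assumes a `SparkSession` named `spark`): read the configuration value on the driver and let the closure capture it, rather than calling `SQLConf.get` from executor-side code.

```
// Driver side: read the config once and capture the value.
val shufflePartitions = spark.conf.get("spark.sql.shuffle.partitions")

// Executor side: only the captured value is used; no SQLConf.get in the closure.
val labels = spark.sparkContext.parallelize(1 to 3).map { i =>
  s"task $i saw spark.sql.shuffle.partitions=$shufflePartitions"
}.collect()
```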
--- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4f77c54a7af5..80cdc61484c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,11 +27,13 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -70,7 +72,7 @@ object SQLConf { * Default config. Only used when there is no active SparkSession for the thread. * See [[get]] for more information. */ - private val fallbackConf = new ThreadLocal[SQLConf] { + private lazy val fallbackConf = new ThreadLocal[SQLConf] { override def initialValue: SQLConf = new SQLConf } @@ -1087,6 +1089,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) From 1a87a1609c4d2c9027a2cf669ea3337b89f61fb6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 3 Jan 2018 22:09:30 +0800 Subject: [PATCH 070/100] [SPARK-22934][SQL] Make optional clauses order insensitive for CREATE TABLE SQL statement ## What changes were proposed in this pull request? Currently, our CREATE TABLE syntax require the EXACT order of clauses. It is pretty hard to remember the exact order. Thus, this PR is to make optional clauses order insensitive for `CREATE TABLE` SQL statement. ``` CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1 col_type1 [COMMENT col_comment1], ...)] USING datasource [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` The same idea is also applicable to Create Hive Table. 
``` CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1[:] col_type1 [COMMENT col_comment1], ...)] [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` ## How was this patch tested? Added test cases Author: gatorsmile Closes #20133 from gatorsmile/createDataSourceTableDDL. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 24 +- .../sql/catalyst/parser/ParserUtils.scala | 9 + .../spark/sql/execution/SparkSqlParser.scala | 81 +++++-- .../execution/command/DDLParserSuite.scala | 220 ++++++++++++++---- .../sql/execution/command/DDLSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 13 +- .../sql/hive/execution/SQLQuerySuite.scala | 124 +++++----- 7 files changed, 335 insertions(+), 138 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6fe995f650d5..6daf01d98426 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -73,18 +73,22 @@ statement | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS options=tablePropertyList)? - (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? locationSpec? - (COMMENT comment=STRING)? - (TBLPROPERTIES tableProps=tablePropertyList)? + ((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitionColumnNames=identifierList) | + bucketSpec | + locationSpec | + (COMMENT comment=STRING) | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? - (COMMENT comment=STRING)? - (PARTITIONED BY '(' partitionColumns=colTypeList ')')? - bucketSpec? skewSpec? - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? + ((COMMENT comment=STRING) | + (PARTITIONED BY '(' partitionColumns=colTypeList ')') | + bucketSpec | + skewSpec | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createHiveTable | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier locationSpec? 
#createTableLike diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 9b127f91648e..89347f4b1f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.catalyst.parser +import java.util + import scala.collection.mutable.StringBuilder import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -39,6 +41,13 @@ object ParserUtils { throw new ParseException(s"Operation not allowed: $message", ctx) } + def checkDuplicateClauses[T]( + nodes: util.List[T], clauseName: String, ctx: ParserRuleContext): Unit = { + if (nodes.size() > 1) { + throw new ParseException(s"Found duplicate clauses: $clauseName", ctx) + } + } + /** Check if duplicate keys exist in a set of key-value pairs. */ def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = { keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 29b584b55972..d3cfd2a1ffbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -383,16 +383,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name * USING table_provider - * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, col_name, ...)] - * [CLUSTERED BY (col_name, col_name, ...) - * [SORTED BY (col_name [ASC|DESC], ...)] - * INTO num_buckets BUCKETS - * ] - * [LOCATION path] - * [COMMENT table_comment] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { @@ -400,6 +403,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (external) { operationNotAllowed("CREATE EXTERNAL TABLE ... 
USING", ctx) } + + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText val schema = Option(ctx.colTypeList()).map(createSchema) @@ -408,9 +419,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { .map(visitIdentifierList(_).toArray) .getOrElse(Array.empty[String]) val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) val storage = DataSource.buildStorageFormatFromOptions(options) if (location.isDefined && storage.locationUri.isDefined) { @@ -1087,13 +1098,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name * [(col1[:] data_type [COMMENT col_comment], ...)] - * [COMMENT table_comment] - * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] - * [ROW FORMAT row_format] - * [STORED AS file_format] - * [LOCATION path] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [AS select_statement]; + * + * create_table_clauses (order insensitive): + * [COMMENT table_comment] + * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { @@ -1104,15 +1118,23 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { "CREATE TEMPORARY TABLE is not supported yet. " + "Please use CREATE TEMPORARY VIEW as an alternative.", ctx) } - if (ctx.skewSpec != null) { + if (ctx.skewSpec.size > 0) { operationNotAllowed("CREATE TABLE ... 
SKEWED BY", ctx) } + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil) val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil) - val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) val selectQuery = Option(ctx.query).map(plan) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) // Note: Hive requires partition columns to be distinct from the schema, so we need // to include the partition columns here explicitly @@ -1120,12 +1142,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { // Storage format val defaultStorage = HiveSerDe.getDefaultStorage(conf) - validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) - val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + validateRowFormatFileFormat(ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx) + val fileStorage = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat) .getOrElse(CatalogStorageFormat.empty) - val rowStorage = Option(ctx.rowFormat).map(visitRowFormat) + val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat) .getOrElse(CatalogStorageFormat.empty) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) // If we are creating an EXTERNAL table, then the LOCATION field is required if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) @@ -1180,7 +1202,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx) } - val hasStorageProperties = (ctx.createFileFormat != null) || (ctx.rowFormat != null) + val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0) if (conf.convertCTAS && !hasStorageProperties) { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. @@ -1366,6 +1388,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } } + private def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + /** * Create or replace a view. This creates a [[CreateViewCommand]] command. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index eb7c33590b60..2b1aea08b122 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -54,6 +54,13 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parser.parsePlan(sqlCommand)).getMessage + messages.foreach { message => + assert(e.contains(message)) + } + } + private def parseAs[T: ClassTag](query: String): T = { parser.parsePlan(query) match { case t: T => t @@ -494,6 +501,37 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Duplicate clauses - create table") { + def createTableHeader(duplicateClause: String, isNative: Boolean): String = { + val fileFormat = if (isNative) "USING parquet" else "STORED AS parquet" + s"CREATE TABLE my_tab(a INT, b STRING) $fileFormat $duplicateClause $duplicateClause" + } + + Seq(true, false).foreach { isNative => + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')", isNative), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'", isNative), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'", isNative), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS", isNative), + "Found duplicate clauses: CLUSTERED BY") + } + + // Only for native data source tables + intercept(createTableHeader("PARTITIONED BY (b)", isNative = true), + "Found duplicate clauses: PARTITIONED BY") + + // Only for Hive serde tables + intercept(createTableHeader("PARTITIONED BY (k int)", isNative = false), + "Found duplicate clauses: PARTITIONED BY") + intercept(createTableHeader("STORED AS parquet", isNative = false), + "Found duplicate clauses: STORED AS/BY") + intercept( + createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'", isNative = false), + "Found duplicate clauses: ROW FORMAT") + } + test("create table - with location") { val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" @@ -1153,38 +1191,119 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Test CTAS against data source tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + 
assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.provider == Some("parquet")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } + test("Test CTAS #1") { val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |STORED AS RCFILE |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s1) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - // TODO will be SQLText - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #2") { - val s2 = - 
"""CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + val s1 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' | STORED AS @@ -1192,26 +1311,45 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s2) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - // TODO will be SQLText - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #3") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fdb9b2f51f9c..591510c1d828 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1971,8 +1971,8 @@ abstract class DDLSuite extends QueryTest with 
SQLTestUtils { s""" |CREATE TABLE t(a int, b int, c int, d int) |USING parquet - |PARTITIONED BY(a, b) |LOCATION "${dir.toURI}" + |PARTITIONED BY(a, b) """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index f2e0c695ca38..65be24441867 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -875,12 +875,13 @@ class HiveDDLSuite test("desc table for Hive table - bucketed + sorted table") { withTable("tbl") { - sql(s""" - CREATE TABLE tbl (id int, name string) - PARTITIONED BY (ds string) - CLUSTERED BY(id) - SORTED BY(id, name) INTO 1024 BUCKETS - """) + sql( + s""" + |CREATE TABLE tbl (id int, name string) + |CLUSTERED BY(id) + |SORTED BY(id, name) INTO 1024 BUCKETS + |PARTITIONED BY (ds string) + """.stripMargin) val x = sql("DESC FORMATTED tbl").collect() assert(x.containsSlice( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 07ae3ae94584..47adc77a52d5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -461,51 +461,55 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS without serde without location") { - val originalConf = sessionState.conf.convertCTAS - - setConf(SQLConf.CONVERT_CTAS, true) - - val defaultDataSource = sessionState.conf.defaultDataSourceName - try { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - val message = intercept[AnalysisException] { + withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { + val defaultDataSource = sessionState.conf.defaultDataSourceName + withTable("ctas1") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("already exists")) - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + val message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("already exists")) + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } // Specifying database name for query can be converted to data source write path // is not allowed right now. 
- sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } - sql("CREATE TABLE ctas1 stored as textfile" + + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as textfile" + " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text") - sql("DROP TABLE ctas1") + checkRelation("ctas1", isDataSourceTable = false, "text") + } - sql("CREATE TABLE ctas1 stored as sequencefile" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "sequence") + } - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "rcfile") + } - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "orc") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "orc") + } - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "parquet") - sql("DROP TABLE ctas1") - } finally { - setConf(SQLConf.CONVERT_CTAS, originalConf) - sql("DROP TABLE IF EXISTS ctas1") + withTable("ctas1") { + sql( + """ + |CREATE TABLE ctas1 stored as parquet + |AS SELECT key k, value FROM src ORDER BY k, value + """.stripMargin) + checkRelation("ctas1", isDataSourceTable = false, "parquet") + } } } @@ -539,30 +543,40 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val defaultDataSource = sessionState.conf.defaultDataSourceName val tempLocation = dir.toURI.getPath.stripSuffix("/") - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c1")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c1")) + } - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c2")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c2")) + } - sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text", 
Some(s"file:$tempLocation/c3")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "text", Some(s"file:$tempLocation/c3")) + } - sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence", Some(s"file:$tempLocation/c4")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "sequence", Some(s"file:$tempLocation/c4")) + } - sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile", Some(s"file:$tempLocation/c5")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "rcfile", Some(s"file:$tempLocation/c5")) + } } } } From a66fe36cee9363b01ee70e469f1c968f633c5713 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 Jan 2018 22:18:13 +0800 Subject: [PATCH 071/100] [SPARK-20236][SQL] dynamic partition overwrite ## What changes were proposed in this pull request? When overwriting a partitioned table with dynamic partition columns, the behavior is different between data source and hive tables. data source table: delete all partition directories that match the static partition values provided in the insert statement. hive table: only delete partition directories which have data written into it This PR adds a new config to make users be able to choose hive's behavior. ## How was this patch tested? new tests Author: Wenchen Fan Closes #18714 from cloud-fan/overwrite-partition. --- .../internal/io/FileCommitProtocol.scala | 25 ++++-- .../io/HadoopMapReduceCommitProtocol.scala | 75 ++++++++++++++---- .../apache/spark/sql/internal/SQLConf.scala | 21 +++++ .../InsertIntoHadoopFsRelationCommand.scala | 20 ++++- .../SQLHadoopMapReduceCommitProtocol.scala | 10 ++- .../spark/sql/sources/InsertSuite.scala | 78 +++++++++++++++++++ 6 files changed, 200 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 50f51e1af453..6d0059b6a027 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -28,8 +28,9 @@ import org.apache.spark.util.Utils * * 1. Implementations must be serializable, as the committer instance instantiated on the driver * will be used for tasks on executors. - * 2. Implementations should have a constructor with 2 arguments: - * (jobId: String, path: String) + * 2. Implementations should have a constructor with 2 or 3 arguments: + * (jobId: String, path: String) or + * (jobId: String, path: String, dynamicPartitionOverwrite: Boolean) * 3. A committer should not be reused across multiple Spark jobs. * * The proper call sequence is: @@ -139,10 +140,22 @@ object FileCommitProtocol { /** * Instantiates a FileCommitProtocol using the given className. 
*/ - def instantiate(className: String, jobId: String, outputPath: String) - : FileCommitProtocol = { + def instantiate( + className: String, + jobId: String, + outputPath: String, + dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = { val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] - val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) - ctor.newInstance(jobId, outputPath) + // First try the constructor with arguments (jobId: String, outputPath: String, + // dynamicPartitionOverwrite: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 95c99d29c3a9..6d20ef1f98a3 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -39,8 +39,19 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * * @param jobId the job's or stage's id * @param path the job's output path, or null if committer acts as a noop + * @param dynamicPartitionOverwrite If true, Spark will overwrite partition directories at runtime + * dynamically, i.e., we first write files under a staging + * directory with partition path, e.g. + * /path/to/staging/a=1/b=1/xxx.parquet. When committing the job, + * we first clean up the corresponding partition directories at + * destination path, e.g. /path/to/destination/a=1/b=1, and move + * files from staging directory to the corresponding partition + * directories under destination path. */ -class HadoopMapReduceCommitProtocol(jobId: String, path: String) +class HadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) extends FileCommitProtocol with Serializable with Logging { import FileCommitProtocol._ @@ -67,9 +78,17 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) @transient private var addedAbsPathFiles: mutable.Map[String, String] = null /** - * The staging directory for all files committed with absolute output paths. + * Tracks partitions with default path that have new files written into them by this task, + * e.g. a=1/b=2. Files under these partitions will be saved into staging directory and moved to + * destination directory at the end, if `dynamicPartitionOverwrite` is true. */ - private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + @transient private var partitionPaths: mutable.Set[String] = null + + /** + * The staging directory of this write job. Spark uses it to deal with files with absolute output + * path, or writing data into partitioned directory with dynamicPartitionOverwrite=true. 
+ */ + private def stagingDir = new Path(path, ".spark-staging-" + jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() @@ -85,11 +104,16 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { val filename = getFilename(taskContext, ext) - val stagingDir: String = committer match { + val stagingDir: Path = committer match { + case _ if dynamicPartitionOverwrite => + assert(dir.isDefined, + "The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.") + partitionPaths += dir.get + this.stagingDir // For FileOutputCommitter it has its own staging path called "work path". case f: FileOutputCommitter => - Option(f.getWorkPath).map(_.toString).getOrElse(path) - case _ => path + new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path)) + case _ => new Path(path) } dir.map { d => @@ -106,8 +130,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) // Include a UUID here to prevent file collisions for one task writing to different dirs. // In principle we could include hash(absoluteDir) instead but this is simpler. - val tmpOutputPath = new Path( - absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + val tmpOutputPath = new Path(stagingDir, UUID.randomUUID().toString() + "-" + filename).toString addedAbsPathFiles(tmpOutputPath) = absOutputPath tmpOutputPath @@ -141,23 +164,42 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { committer.commitJob(jobContext) - val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) - .foldLeft(Map[String, String]())(_ ++ _) - logDebug(s"Committing files staged for absolute locations $filesToMove") + if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + val (allAbsPathFiles, allPartitionPaths) = + taskCommits.map(_.obj.asInstanceOf[(Map[String, String], Set[String])]).unzip + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + + val filesToMove = allAbsPathFiles.foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + if (dynamicPartitionOverwrite) { + val absPartitionPaths = filesToMove.values.map(new Path(_).getParent).toSet + logDebug(s"Clean up absolute partition directories for overwriting: $absPartitionPaths") + absPartitionPaths.foreach(fs.delete(_, true)) + } for ((src, dst) <- filesToMove) { fs.rename(new Path(src), new Path(dst)) } - fs.delete(absPathStagingDir, true) + + if (dynamicPartitionOverwrite) { + val partitionPaths = allPartitionPaths.foldLeft(Set[String]())(_ ++ _) + logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") + for (part <- partitionPaths) { + val finalPartPath = new Path(path, part) + fs.delete(finalPartPath, true) + fs.rename(new Path(stagingDir, part), finalPartPath) + } + } + + fs.delete(stagingDir, true) } } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - fs.delete(absPathStagingDir, true) + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(stagingDir, true) } } @@ -165,13 +207,14 @@ class HadoopMapReduceCommitProtocol(jobId: String, 
path: String) committer = setupCommitter(taskContext) committer.setupTask(taskContext) addedAbsPathFiles = mutable.Map[String, String]() + partitionPaths = mutable.Set[String]() } override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { val attemptId = taskContext.getTaskAttemptID SparkHadoopMapRedUtil.commitTask( committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - new TaskCommitMessage(addedAbsPathFiles.toMap) + new TaskCommitMessage(addedAbsPathFiles.toMap -> partitionPaths.toSet) } override def abortTask(taskContext: TaskAttemptContext): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80cdc61484c0..5d6edf6b8abe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1068,6 +1068,24 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + object PartitionOverwriteMode extends Enumeration { + val STATIC, DYNAMIC = Value + } + + val PARTITION_OVERWRITE_MODE = + buildConf("spark.sql.sources.partitionOverwriteMode") + .doc("When INSERT OVERWRITE a partitioned data source table, we currently support 2 modes: " + + "static and dynamic. In static mode, Spark deletes all the partitions that match the " + + "partition specification(e.g. PARTITION(a=1,b)) in the INSERT statement, before " + + "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " + + "those partitions that have data written into it at runtime. By default we use static " + + "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " + + "affect Hive serde tables, as they are always overwritten with dynamic mode.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(PartitionOverwriteMode.values.map(_.toString)) + .createWithDefault(PartitionOverwriteMode.STATIC.toString) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1394,6 +1412,9 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = + PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index ad24e280d942..dd7ef0d15c14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.util.SchemaUtils /** @@ -89,13 +90,19 @@ case class InsertIntoHadoopFsRelationCommand( } val pathExists = fs.exists(qualifiedOutputPath) - // If we are appending data to an existing dir. 
- val isAppend = pathExists && (mode == SaveMode.Append) + + val enableDynamicOverwrite = + sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + // This config only makes sense when we are overwriting a partitioned dataset with dynamic + // partition columns. + val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && + staticPartitions.size < partitionColumns.length val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, - outputPath = outputPath.toString) + outputPath = outputPath.toString, + dynamicPartitionOverwrite = dynamicPartitionOverwrite) val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => @@ -103,6 +110,9 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.Overwrite, true) => if (ifPartitionNotExists && matchingPartitions.nonEmpty) { false + } else if (dynamicPartitionOverwrite) { + // For dynamic partition overwrite, do not delete partition directories ahead. + true } else { deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) true @@ -126,7 +136,9 @@ case class InsertIntoHadoopFsRelationCommand( catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), ifNotExists = true).run(sparkSession) } - if (mode == SaveMode.Overwrite) { + // For dynamic partition overwrite, we never remove partitions but only update existing + // ones. + if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) { val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions if (deletedPartitions.nonEmpty) { AlterTableDropPartitionCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala index 40825a1f724b..39c594a9bc61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala @@ -29,11 +29,15 @@ import org.apache.spark.sql.internal.SQLConf * A variant of [[HadoopMapReduceCommitProtocol]] that allows specifying the actual * Hadoop output committer using an option specified in SQLConf. 
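Alongside the committer plumbing, the new SQLConf entry above is what users actually toggle. A rough sketch of the intended behavior, loosely mirroring the InsertSuite tests added further below (assumes an existing SparkSession named `spark`; the table and values are invented):

```
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

spark.sql("CREATE TABLE t(i INT, part1 INT, part2 INT) USING parquet PARTITIONED BY (part1, part2)")
spark.sql("INSERT INTO t PARTITION (part1=1, part2=1) SELECT 1")
spark.sql("INSERT INTO t PARTITION (part1=1, part2=2) SELECT 2")

// Only (part1=1, part2=1) receives data, so only that directory is rewritten and
// (part1=1, part2=2) keeps its old row. Under the default 'static' mode, every
// partition matching part1=1 would have been deleted before the write.
spark.sql("INSERT OVERWRITE TABLE t PARTITION (part1=1, part2) SELECT 10, 1")
```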
*/ -class SQLHadoopMapReduceCommitProtocol(jobId: String, path: String) - extends HadoopMapReduceCommitProtocol(jobId, path) with Serializable with Logging { +class SQLHadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) + extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) + with Serializable with Logging { override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { - var committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) + var committer = super.setupCommitter(context) val configuration = context.getConfiguration val clazz = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 8b7e2e5f4594..fef01c860db6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -21,6 +21,8 @@ import java.io.File import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -442,4 +444,80 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { assert(e.contains("Only Data Sources providing FileFormat are supported")) } } + + test("SPARK-20236: dynamic partition overwrite without catalog table") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTempPath { path => + Seq((1, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1, 1)) + + Seq((2, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1)) + + Seq((2, 2, 2)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + sql("insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite with customer partition path") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + 
""" + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + val path1 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=1) location '$path1'") + sql(s"insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + val path2 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=2) location '$path2'") + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } } From 9a2b65a3c0c36316aae0a53aa0f61c5044c2ceff Mon Sep 17 00:00:00 2001 From: chetkhatri Date: Wed, 3 Jan 2018 11:31:32 -0600 Subject: [PATCH 072/100] [SPARK-22896] Improvement in String interpolation ## What changes were proposed in this pull request? * String interpolation in ml pipeline example has been corrected as per scala standard. ## How was this patch tested? * manually tested. Author: chetkhatri Closes #20070 from chetkhatri/mllib-chetan-contrib. --- .../spark/examples/ml/JavaQuantileDiscretizerExample.java | 2 +- .../apache/spark/examples/SimpleSkewedGroupByTest.scala | 4 ---- .../org/apache/spark/examples/graphx/Analytics.scala | 6 ++++-- .../org/apache/spark/examples/graphx/SynthBenchmark.scala | 6 +++--- .../apache/spark/examples/ml/ChiSquareTestExample.scala | 6 +++--- .../org/apache/spark/examples/ml/CorrelationExample.scala | 4 ++-- .../org/apache/spark/examples/ml/DataFrameExample.scala | 4 ++-- .../examples/ml/DecisionTreeClassificationExample.scala | 4 ++-- .../spark/examples/ml/DecisionTreeRegressionExample.scala | 4 ++-- .../apache/spark/examples/ml/DeveloperApiExample.scala | 6 +++--- .../examples/ml/EstimatorTransformerParamExample.scala | 6 +++--- .../ml/GradientBoostedTreeClassifierExample.scala | 4 ++-- .../examples/ml/GradientBoostedTreeRegressorExample.scala | 4 ++-- ...ulticlassLogisticRegressionWithElasticNetExample.scala | 2 +- .../ml/MultilayerPerceptronClassifierExample.scala | 2 +- .../org/apache/spark/examples/ml/NaiveBayesExample.scala | 2 +- .../spark/examples/ml/QuantileDiscretizerExample.scala | 4 ++-- .../spark/examples/ml/RandomForestClassifierExample.scala | 4 ++-- .../spark/examples/ml/RandomForestRegressorExample.scala | 4 ++-- .../apache/spark/examples/ml/VectorIndexerExample.scala | 4 ++-- .../spark/examples/mllib/AssociationRulesExample.scala | 6 +++--- .../mllib/BinaryClassificationMetricsExample.scala | 4 ++-- .../mllib/DecisionTreeClassificationExample.scala | 4 ++-- .../examples/mllib/DecisionTreeRegressionExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/FPGrowthExample.scala | 2 +- .../mllib/GradientBoostingClassificationExample.scala | 4 ++-- .../mllib/GradientBoostingRegressionExample.scala | 4 ++-- .../spark/examples/mllib/HypothesisTestingExample.scala | 2 +- .../spark/examples/mllib/IsotonicRegressionExample.scala | 2 +- .../org/apache/spark/examples/mllib/KMeansExample.scala | 2 +- .../org/apache/spark/examples/mllib/LBFGSExample.scala | 2 +- .../examples/mllib/LatentDirichletAllocationExample.scala | 8 
+++++--- .../examples/mllib/LinearRegressionWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/PCAExample.scala | 4 ++-- .../spark/examples/mllib/PMMLModelExportExample.scala | 2 +- .../apache/spark/examples/mllib/PrefixSpanExample.scala | 4 ++-- .../mllib/RandomForestClassificationExample.scala | 4 ++-- .../examples/mllib/RandomForestRegressionExample.scala | 4 ++-- .../spark/examples/mllib/RecommendationExample.scala | 2 +- .../apache/spark/examples/mllib/SVMWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/SimpleFPGrowth.scala | 8 +++----- .../spark/examples/mllib/StratifiedSamplingExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/TallSkinnyPCA.scala | 2 +- .../org/apache/spark/examples/mllib/TallSkinnySVD.scala | 2 +- .../apache/spark/examples/streaming/CustomReceiver.scala | 6 +++--- .../apache/spark/examples/streaming/RawNetworkGrep.scala | 2 +- .../examples/streaming/RecoverableNetworkWordCount.scala | 8 ++++---- .../streaming/clickstream/PageViewGenerator.scala | 4 ++-- .../examples/streaming/clickstream/PageViewStream.scala | 4 ++-- 49 files changed, 94 insertions(+), 96 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java index dd20cac62110..43cc30c1a899 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -66,7 +66,7 @@ public static void main(String[] args) { .setNumBuckets(3); Dataset result = discretizer.fit(df).transform(df); - result.show(); + result.show(false); // $example off$ spark.stop(); } diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index e64dcbd182d9..2332a661f26a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -60,10 +60,6 @@ object SimpleSkewedGroupByTest { pairs1.count println(s"RESULT: ${pairs1.groupByKey(numReducers).count}") - // Print how many keys each reducer got (for debugging) - // println("RESULT: " + pairs1.groupByKey(numReducers) - // .map{case (k,v) => (k, v.size)} - // .collectAsMap) spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 92936bd30dbc..815404d1218b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -145,9 +145,11 @@ object Analytics extends Logging { // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) - println("Triangles: " + triangles.vertices.map { + val triangleTypes = triangles.vertices.map { case (vid, data) => data.toLong - }.reduce(_ + _) / 3) + }.reduce(_ + _) / 3 + + println(s"Triangles: ${triangleTypes}") sc.stop() case _ => diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 6d2228c8742a..57b2edf99220 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -52,7 +52,7 @@ object SynthBenchmark { arg => arg.dropWhile(_ == '-').split('=') match { case Array(opt, v) => (opt -> v) - case _ => throw new IllegalArgumentException("Invalid argument: " + arg) + case _ => throw new IllegalArgumentException(s"Invalid argument: $arg") } } @@ -76,7 +76,7 @@ object SynthBenchmark { case ("sigma", v) => sigma = v.toDouble case ("degFile", v) => degFile = v case ("seed", v) => seed = v.toInt - case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) + case (opt, _) => throw new IllegalArgumentException(s"Invalid option: $opt") } val conf = new SparkConf() @@ -86,7 +86,7 @@ object SynthBenchmark { val sc = new SparkContext(conf) // Create the graph - println(s"Creating graph...") + println("Creating graph...") val unpartitionedGraph = GraphGenerators.logNormalGraph(sc, numVertices, numEPart.getOrElse(sc.defaultParallelism), mu, sigma, seed) // Repartition the graph diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala index dcee1e427ce5..5146fd031646 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala @@ -52,9 +52,9 @@ object ChiSquareTestExample { val df = data.toDF("label", "features") val chi = ChiSquareTest.test(df, "features", "label").head - println("pValues = " + chi.getAs[Vector](0)) - println("degreesOfFreedom = " + chi.getSeq[Int](1).mkString("[", ",", "]")) - println("statistics = " + chi.getAs[Vector](2)) + println(s"pValues = ${chi.getAs[Vector](0)}") + println(s"degreesOfFreedom ${chi.getSeq[Int](1).mkString("[", ",", "]")}") + println(s"statistics ${chi.getAs[Vector](2)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala index 3f57dc342eb0..d7f1fc8ed74d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala @@ -51,10 +51,10 @@ object CorrelationExample { val df = data.map(Tuple1.apply).toDF("features") val Row(coeff1: Matrix) = Correlation.corr(df, "features").head - println("Pearson correlation matrix:\n" + coeff1.toString) + println(s"Pearson correlation matrix:\n $coeff1") val Row(coeff2: Matrix) = Correlation.corr(df, "features", "spearman").head - println("Spearman correlation matrix:\n" + coeff2.toString) + println(s"Spearman correlation matrix:\n $coeff2") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 0658bddf1696..ee4469faab3a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -47,7 +47,7 @@ object DataFrameExample { val parser = new OptionParser[Params]("DataFrameExample") { head("DataFrameExample: an example app using DataFrame for ML.") opt[String]("input") - .text(s"input path to dataframe") + .text("input path to dataframe") .action((x, c) => c.copy(input = x)) 
checkConfig { params => success @@ -93,7 +93,7 @@ object DataFrameExample { // Load the records back. println(s"Loading Parquet file with UDT from $outputDir.") val newDF = spark.read.parquet(outputDir) - println(s"Schema from Parquet:") + println("Schema from Parquet:") newDF.printSchema() spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index bc6d3275933e..276cedab11ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -83,10 +83,10 @@ object DecisionTreeClassificationExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] - println("Learned classification tree model:\n" + treeModel.toDebugString) + println(s"Learned classification tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index ee61200ad1d0..aaaecaea4708 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -73,10 +73,10 @@ object DecisionTreeRegressionExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] - println("Learned regression tree model:\n" + treeModel.toDebugString) + println(s"Learned regression tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index d94d837d10e9..2dc11b07d88e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -53,7 +53,7 @@ object DeveloperApiExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new MyLogisticRegression() // Print out the parameters, documentation, and any default values. - println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"MyLogisticRegression parameters:\n ${lr.explainParams()}") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -169,10 +169,10 @@ private class MyLogisticRegressionModel( Vectors.dense(-margin, margin) } - /** Number of classes the label can take. 2 indicates binary classification. */ + // Number of classes the label can take. 2 indicates binary classification. override val numClasses: Int = 2 - /** Number of features the model was trained on. */ + // Number of features the model was trained on. 
override val numFeatures: Int = coefficients.size /** diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala index f18d86e1a692..e5d91f132a3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala @@ -46,7 +46,7 @@ object EstimatorTransformerParamExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() // Print out the parameters, documentation, and any default values. - println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"LogisticRegression parameters:\n ${lr.explainParams()}\n") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -58,7 +58,7 @@ object EstimatorTransformerParamExample { // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. - println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + println(s"Model 1 was fit using parameters: ${model1.parent.extractParamMap}") // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -73,7 +73,7 @@ object EstimatorTransformerParamExample { // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training, paramMapCombined) - println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + println(s"Model 2 was fit using parameters: ${model2.parent.extractParamMap}") // Prepare test data. 
val test = spark.createDataFrame(Seq( diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 3656773c8b81..ef78c0a1145e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -86,10 +86,10 @@ object GradientBoostedTreeClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${1.0 - accuracy}") val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] - println("Learned classification GBT model:\n" + gbtModel.toDebugString) + println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala index e53aab7f326d..3feb2343f6a8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -73,10 +73,10 @@ object GradientBoostedTreeRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] - println("Learned regression GBT model:\n" + gbtModel.toDebugString) + println(s"Learned regression GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala index 42f0ace7a353..3e61dbe628c2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -48,7 +48,7 @@ object MulticlassLogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for multinomial logistic regression println(s"Coefficients: \n${lrModel.coefficientMatrix}") - println(s"Intercepts: ${lrModel.interceptVector}") + println(s"Intercepts: \n${lrModel.interceptVector}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 6fce82d294f8..646f46a92506 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,7 +66,7 @@ object MultilayerPerceptronClassifierExample { val evaluator = new MulticlassClassificationEvaluator() .setMetricName("accuracy") - println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels)) + println(s"Test set accuracy = 
${evaluator.evaluate(predictionAndLabels)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala index bd9fcc420a66..50c70c626b12 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala @@ -52,7 +52,7 @@ object NaiveBayesExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test set accuracy = " + accuracy) + println(s"Test set accuracy = $accuracy") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index aedb9e7d3bb7..0fe16fb6dfa9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -36,7 +36,7 @@ object QuantileDiscretizerExample { // Output of QuantileDiscretizer for such small datasets can depend on the number of // partitions. Here we force a single partition to ensure consistent results. // Note this is not necessary for normal use cases - .repartition(1) + .repartition(1) // $example on$ val discretizer = new QuantileDiscretizer() @@ -45,7 +45,7 @@ object QuantileDiscretizerExample { .setNumBuckets(3) val result = discretizer.fit(df).transform(df) - result.show() + result.show(false) // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala index 5eafda8ce428..6265f8390252 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -85,10 +85,10 @@ object RandomForestClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] - println("Learned classification forest model:\n" + rfModel.toDebugString) + println(s"Learned classification forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala index 9a0a001c26ef..2679fcb353a8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -72,10 +72,10 @@ object RandomForestRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] - println("Learned regression forest model:\n" + rfModel.toDebugString) + println(s"Learned regression forest model:\n ${rfModel.toDebugString}") // 
$example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala index afa761aee0b9..96bb8ea2338a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -41,8 +41,8 @@ object VectorIndexerExample { val indexerModel = indexer.fit(data) val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet - println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) + println(s"Chose ${categoricalFeatures.size} " + + s"categorical features: ${categoricalFeatures.mkString(", ")}") // Create new column "indexed" with categorical values transformed to indices val indexedData = indexerModel.transform(data) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index ff44de56839e..a07535bb5a38 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -42,9 +42,8 @@ object AssociationRulesExample { val results = ar.run(freqItemsets) results.collect().foreach { rule => - println("[" + rule.antecedent.mkString(",") - + "=>" - + rule.consequent.mkString(",") + "]," + rule.confidence) + println(s"[${rule.antecedent.mkString(",")}=>${rule.consequent.mkString(",")} ]" + + s" ${rule.confidence}") } // $example off$ @@ -53,3 +52,4 @@ object AssociationRulesExample { } // scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala index b9263ac6fcff..c6312d71cc91 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -86,7 +86,7 @@ object BinaryClassificationMetricsExample { // AUPRC val auPRC = metrics.areaUnderPR - println("Area under precision-recall curve = " + auPRC) + println(s"Area under precision-recall curve = $auPRC") // Compute thresholds used in ROC and PR curves val thresholds = precision.map(_._1) @@ -96,7 +96,7 @@ object BinaryClassificationMetricsExample { // AUROC val auROC = metrics.areaUnderROC - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala index b50b4592777c..c2f89b72c9a2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -55,8 +55,8 @@ object DecisionTreeClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification tree model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned 
classification tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala index 2af45afae3d5..1ecf6426e1f9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -54,8 +54,8 @@ object DecisionTreeRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression tree model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 6435abc12775..f724ee1030f0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -74,7 +74,7 @@ object FPGrowthExample { println(s"Number of frequent itemsets: ${model.freqItemsets.count()}") model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")}, ${itemset.freq}") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala index 00bb3348d2a3..3c56e1941aec 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -54,8 +54,8 @@ object GradientBoostingClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification GBT model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala index d8c263460839..c288bf29bf25 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -53,8 +53,8 @@ object GradientBoostingRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression GBT model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression GBT model:\n 
${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala index 0d391a3637c0..add171973953 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala @@ -68,7 +68,7 @@ object HypothesisTestingExample { // against the label. val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) featureTestResults.zipWithIndex.foreach { case (k, v) => - println("Column " + (v + 1).toString + ":") + println(s"Column ${(v + 1)} :") println(k) } // summary of the test // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 4aee951f5b04..a10d6f0dda88 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -56,7 +56,7 @@ object IsotonicRegressionExample { // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map { case (p, l) => math.pow((p - l), 2) }.mean() - println("Mean Squared Error = " + meanSquaredError) + println(s"Mean Squared Error = $meanSquaredError") // Save and load model model.save(sc, "target/tmp/myIsotonicRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala index c4d71d862f37..b0a6f1671a89 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala @@ -43,7 +43,7 @@ object KMeansExample { // Evaluate clustering by computing Within Set Sum of Squared Errors val WSSSE = clusters.computeCost(parsedData) - println("Within Set Sum of Squared Errors = " + WSSSE) + println(s"Within Set Sum of Squared Errors = $WSSSE") // Save and load model clusters.save(sc, "target/org/apache/spark/KMeansExample/KMeansModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala index fedcefa09838..123782fa6b9c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -82,7 +82,7 @@ object LBFGSExample { println("Loss of each step in training process") loss.foreach(println) - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala index f2c8ec01439f..d25962c5500e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala @@ -42,11 +42,13 @@ object LatentDirichletAllocationExample { val ldaModel = new LDA().setK(3).run(corpus) // 
Output topics. Each is a distribution over words (matching word count vectors) - println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") + println(s"Learned topics (as distributions over vocab of ${ldaModel.vocabSize} words):") val topics = ldaModel.topicsMatrix for (topic <- Range(0, 3)) { - print("Topic " + topic + ":") - for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + print(s"Topic $topic :") + for (word <- Range(0, ldaModel.vocabSize)) { + print(s"${topics(word, topic)}") + } println() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala index d39961809448..449b725d1d17 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala @@ -52,7 +52,7 @@ object LinearRegressionWithSGDExample { (point.label, prediction) } val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean() - println("training Mean Squared Error = " + MSE) + println(s"training Mean Squared Error $MSE") // Save and load model model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala index eb36697d94ba..eff2393cc3ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala @@ -65,8 +65,8 @@ object PCAExample { val MSE = valuesAndPreds.map { case (v, p) => math.pow((v - p), 2) }.mean() val MSE_pca = valuesAndPreds_pca.map { case (v, p) => math.pow((v - p), 2) }.mean() - println("Mean Squared Error = " + MSE) - println("PCA Mean Squared Error = " + MSE_pca) + println(s"Mean Squared Error = $MSE") + println(s"PCA Mean Squared Error = $MSE_pca") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala index d74d74a37fb1..96deafd469bc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala @@ -41,7 +41,7 @@ object PMMLModelExportExample { val clusters = KMeans.train(parsedData, numClusters, numIterations) // Export to PMML to a String in PMML format - println("PMML Model:\n" + clusters.toPMML) + println(s"PMML Model:\n ${clusters.toPMML}") // Export the model to a local file in PMML format clusters.toPMML("/tmp/kmeans.xml") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala index 69c72c433657..8b789277774a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -42,8 +42,8 @@ object PrefixSpanExample { val model = prefixSpan.run(sequences) model.freqSequences.collect().foreach { freqSequence => println( - freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + - ", " + freqSequence.freq) + s"${freqSequence.sequence.map(_.mkString("[", 
", ", "]")).mkString("[", ", ", "]")}," + + s" ${freqSequence.freq}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala index f1ebdf1a733e..246e71de2561 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -55,8 +55,8 @@ object RandomForestClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification forest model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala index 11d612e651b4..770e30276bc3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -55,8 +55,8 @@ object RandomForestRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression forest model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala index 6df742d737e7..0bb2b8c8c2b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -56,7 +56,7 @@ object RecommendationExample { val err = (r1 - r2) err * err }.mean() - println("Mean Squared Error = " + MSE) + println(s"Mean Squared Error = $MSE") // Save and load model model.save(sc, "target/tmp/myCollaborativeFilter") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala index b73fe9b2b3fa..285e2ce51263 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala @@ -57,7 +57,7 @@ object SVMWithSGDExample { val metrics = new BinaryClassificationMetrics(scoreAndLabels) val auROC = metrics.areaUnderROC() - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // Save and load model model.save(sc, "target/tmp/scalaSVMWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala index b5c3033bcba0..694c3bb18b04 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -42,15 +42,13 @@ object SimpleFPGrowth { val model = fpg.run(transactions) model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")},${itemset.freq}") } val minConfidence = 0.8 model.generateAssociationRules(minConfidence).collect().foreach { rule => - println( - rule.antecedent.mkString("[", ",", "]") - + " => " + rule.consequent .mkString("[", ",", "]") - + ", " + rule.confidence) + println(s"${rule.antecedent.mkString("[", ",", "]")}=> " + + s"${rule.consequent .mkString("[", ",", "]")},${rule.confidence}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala index 16b074ef6069..3d41bef0af88 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala @@ -41,10 +41,10 @@ object StratifiedSamplingExample { val exactSample = data.sampleByKeyExact(withReplacement = false, fractions = fractions) // $example off$ - println("approxSample size is " + approxSample.collect().size.toString) + println(s"approxSample size is ${approxSample.collect().size}") approxSample.collect().foreach(println) - println("exactSample its size is " + exactSample.collect().size.toString) + println(s"exactSample its size is ${exactSample.collect().size}") exactSample.collect().foreach(println) sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 03bc675299c5..071d341b8161 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -54,7 +54,7 @@ object TallSkinnyPCA { // Compute principal components. val pc = mat.computePrincipalComponents(mat.numCols().toInt) - println("Principal components are:\n" + pc) + println(s"Principal components are:\n $pc") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 067e49b9599e..8ae6de16d80e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -54,7 +54,7 @@ object TallSkinnySVD { // Compute SVD. 
val svd = mat.computeSVD(mat.numCols().toInt) - println("Singular values are " + svd.s) + println(s"Singular values are ${svd.s}") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 43044d01b120..25c7bf287197 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -82,9 +82,9 @@ class CustomReceiver(host: String, port: Int) var socket: Socket = null var userInput: String = null try { - logInfo("Connecting to " + host + ":" + port) + logInfo(s"Connecting to $host : $port") socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) + logInfo(s"Connected to $host : $port") val reader = new BufferedReader( new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)) userInput = reader.readLine() @@ -98,7 +98,7 @@ class CustomReceiver(host: String, port: Int) restart("Trying to connect again") } catch { case e: java.net.ConnectException => - restart("Error connecting to " + host + ":" + port, e) + restart(s"Error connecting to $host : $port", e) case t: Throwable => restart("Error receiving data", t) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index 5322929d177b..437ccf0898d7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -54,7 +54,7 @@ object RawNetworkGrep { ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = ssc.union(rawStreams) union.filter(_.contains("the")).count().foreachRDD(r => - println("Grep count: " + r.collect().mkString)) + println(s"Grep count: ${r.collect().mkString}")) ssc.start() ssc.awaitTermination() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 49c042732113..f018f3a26d2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -130,10 +130,10 @@ object RecoverableNetworkWordCount { true } }.collect().mkString("[", ", ", "]") - val output = "Counts at time " + time + " " + counts + val output = s"Counts at time $time $counts" println(output) - println("Dropped " + droppedWordsCounter.value + " word(s) totally") - println("Appending to " + outputFile.getAbsolutePath) + println(s"Dropped ${droppedWordsCounter.value} word(s) totally") + println(s"Appending to ${outputFile.getAbsolutePath}") Files.append(output + "\n", outputFile, Charset.defaultCharset()) } ssc @@ -141,7 +141,7 @@ object RecoverableNetworkWordCount { def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Your arguments were " + args.mkString("[", ", ", "]")) + System.err.println(s"Your arguments were ${args.mkString("[", ", ", "]")}") System.err.println( """ |Usage: RecoverableNetworkWordCount diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala 
b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 0ddd065f0db2..2108bc63edea 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -90,13 +90,13 @@ object PageViewGenerator { val viewsPerSecond = args(1).toFloat val sleepDelayMs = (1000.0 / viewsPerSecond).toInt val listener = new ServerSocket(port) - println("Listening on port: " + port) + println(s"Listening on port: $port") while (true) { val socket = listener.accept() new Thread() { override def run(): Unit = { - println("Got client connected from: " + socket.getInetAddress) + println(s"Got client connected from: ${socket.getInetAddress}") val out = new PrintWriter(socket.getOutputStream(), true) while (true) { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 1ba093f57b32..b8e7c7e9e915 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -104,8 +104,8 @@ object PageViewStream { .foreachRDD((rdd, time) => rdd.join(userList) .map(_._2._2) .take(10) - .foreach(u => println("Saw user %s at time %s".format(u, time)))) - case _ => println("Invalid metric entered: " + metric) + .foreach(u => println(s"Saw user $u at time $time"))) + case _ => println(s"Invalid metric entered: $metric") } ssc.start() From b297029130735316e1ac1144dee44761a12bfba7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 07:28:53 +0800 Subject: [PATCH 073/100] [SPARK-20960][SQL] make ColumnVector public ## What changes were proposed in this pull request? move `ColumnVector` and related classes to `org.apache.spark.sql.vectorized`, and improve the document. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20116 from cloud-fan/column-vector. 
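A minimal sketch of how downstream Scala code might consume the now-public columnar API, assuming a ColumnarBatch produced elsewhere (for example by the vectorized Parquet reader). The method names column, numRows, isNullAt and getInt appear in this patch, but the helper itself is illustrative only and is not part of the change:

import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

// Sum an integer column by reading it column-wise from the batch.
def sumIntColumn(batch: ColumnarBatch, ordinal: Int): Long = {
  val col: ColumnVector = batch.column(ordinal)  // public accessor after the package move
  var sum = 0L
  var i = 0
  while (i < batch.numRows()) {
    if (!col.isNullAt(i)) sum += col.getInt(i)   // per-row getters defined on ColumnVector
    i += 1
  }
  sum
}
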
--- .../VectorizedParquetRecordReader.java | 7 ++- .../vectorized/ColumnVectorUtils.java | 2 + .../vectorized/MutableColumnarRow.java | 4 ++ .../vectorized/WritableColumnVector.java | 7 ++- .../vectorized/ArrowColumnVector.java | 62 +------------------ .../vectorized/ColumnVector.java | 31 ++++++---- .../vectorized/ColumnarArray.java | 7 +-- .../vectorized/ColumnarBatch.java | 34 +++------- .../vectorized/ColumnarRow.java | 7 +-- .../sql/execution/ColumnarBatchScan.scala | 4 +- .../aggregate/HashAggregateExec.scala | 2 +- .../VectorizedHashMapGenerator.scala | 3 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../columnar/InMemoryTableScanExec.scala | 1 + .../execution/datasources/FileScanRDD.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 2 +- .../execution/arrow/ArrowWriterSuite.scala | 2 +- .../vectorized/ArrowColumnVectorSuite.scala | 1 + .../vectorized/ColumnVectorSuite.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 6 +- 20 files changed, 63 insertions(+), 125 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ArrowColumnVector.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnVector.java (79%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarArray.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarBatch.java (73%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarRow.java (96%) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 6c157e85d411..cd745b1f0e4e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,10 +31,10 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; -import org.apache.spark.sql.execution.vectorized.ColumnarBatch; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -248,7 +248,10 @@ public void enableReturningBatches() { * Advances to the next batch of rows. Returns false if there are no more. 
*/ public boolean nextBatch() throws IOException { - columnarBatch.reset(); + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index bc62bc43484e..b5cbe8e2839b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -28,6 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 06602c147dfe..70057a9def6c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -23,6 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5f6f125976e1..d2ae32b06f83 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -585,11 +586,11 @@ public final int appendArray(int length) { public final int appendStruct(boolean isNull) { if (isNull) { appendNull(); - for (ColumnVector c: childColumns) { + for (WritableColumnVector c: childColumns) { if (c.type instanceof StructType) { - ((WritableColumnVector) c).appendStruct(true); + c.appendStruct(true); } else { - ((WritableColumnVector) c).appendNull(); + c.appendNull(); } } } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index af5673e26a50..708333213f3f 100644 --- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; @@ -34,11 +34,7 @@ public final class ArrowColumnVector extends ColumnVector { private ArrowColumnVector[] childColumns; private void ensureAccessible(int index) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index >= valueCount) { - throw new IndexOutOfBoundsException( - String.format("index: %d, valueCount: %d", index, valueCount)); - } + ensureAccessible(index, 1); } private void ensureAccessible(int index, int count) { @@ -64,20 +60,12 @@ public void close() { accessor.close(); } - // - // APIs dealing with nulls - // - @Override public boolean isNullAt(int rowId) { ensureAccessible(rowId); return accessor.isNullAt(rowId); } - // - // APIs dealing with Booleans - // - @Override public boolean getBoolean(int rowId) { ensureAccessible(rowId); @@ -94,10 +82,6 @@ public boolean[] getBooleans(int rowId, int count) { return array; } - // - // APIs dealing with Bytes - // - @Override public byte getByte(int rowId) { ensureAccessible(rowId); @@ -114,10 +98,6 @@ public byte[] getBytes(int rowId, int count) { return array; } - // - // APIs dealing with Shorts - // - @Override public short getShort(int rowId) { ensureAccessible(rowId); @@ -134,10 +114,6 @@ public short[] getShorts(int rowId, int count) { return array; } - // - // APIs dealing with Ints - // - @Override public int getInt(int rowId) { ensureAccessible(rowId); @@ -154,10 +130,6 @@ public int[] getInts(int rowId, int count) { return array; } - // - // APIs dealing with Longs - // - @Override public long getLong(int rowId) { ensureAccessible(rowId); @@ -174,10 +146,6 @@ public long[] getLongs(int rowId, int count) { return array; } - // - // APIs dealing with floats - // - @Override public float getFloat(int rowId) { ensureAccessible(rowId); @@ -194,10 +162,6 @@ public float[] getFloats(int rowId, int count) { return array; } - // - // APIs dealing with doubles - // - @Override public double getDouble(int rowId) { ensureAccessible(rowId); @@ -214,10 +178,6 @@ public double[] getDoubles(int rowId, int count) { return array; } - // - // APIs dealing with Arrays - // - @Override public int getArrayLength(int rowId) { ensureAccessible(rowId); @@ -230,45 +190,27 @@ public int getArrayOffset(int rowId) { return accessor.getArrayOffset(rowId); } - // - // APIs dealing with Decimals - // - @Override public Decimal getDecimal(int rowId, int precision, int scale) { ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } - // - // APIs dealing with UTF8Strings - // - @Override public UTF8String getUTF8String(int rowId) { ensureAccessible(rowId); return accessor.getUTF8String(rowId); } - // - // APIs dealing with Binaries - // - @Override public byte[] getBinary(int rowId) { ensureAccessible(rowId); return accessor.getBinary(rowId); } - /** - * Returns the data for the underlying array. - */ @Override public ArrowColumnVector arrayData() { return childColumns[0]; } - /** - * Returns the ordinal's child data column. 
- */ @Override public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java similarity index 79% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index dc7c1269bedd..d1196e1299fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; @@ -22,24 +22,31 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * This class represents in-memory values of a column and provides the main APIs to access the data. - * It supports all the types and contains get APIs as well as their batched versions. The batched - * versions are considered to be faster and preferable whenever possible. + * An interface representing in-memory columnar data in Spark. This interface defines the main APIs + * to access the data, as well as their batched versions. The batched versions are considered to be + * faster and preferable whenever possible. * - * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these - * columns have child columns. All of the data are stored in the child columns and the parent column - * only contains nullability. In the case of Arrays, the lengths and offsets are saved in the child - * column and are encoded identically to INTs. + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values + * in this ColumnVector. * - * Maps are just a special case of a two field struct. + * ColumnVector supports all the data types including nested types. To handle nested types, + * ColumnVector can have children and is a tree structure. For struct type, it stores the actual + * data of each field in the corresponding child ColumnVector, and only stores null information in + * the parent ColumnVector. For array type, it stores the actual array elements in the child + * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector. * - * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values - * in the current batch. + * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating + * memory again and again. + * + * ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint. + * Implementations should prefer computing efficiency over storage efficiency when design the + * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage + * footprint is negligible. */ public abstract class ColumnVector implements AutoCloseable { /** - * Returns the data type of this column. + * Returns the data type of this column vector. 
*/ public final DataType dataType() { return type; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index cbc39d1d0aec..0d89a52e7a4f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; @@ -23,8 +23,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Array abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Array abstraction in {@link ColumnVector}. */ public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from @@ -33,7 +32,7 @@ public final class ColumnarArray extends ArrayData { private final int offset; private final int length; - ColumnarArray(ColumnVector data, int offset, int length) { + public ColumnarArray(ColumnVector data, int offset, int length) { this.data = data; this.offset = offset; this.length = length; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index a9d09aa67972..9ae1c6d9993f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -14,26 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import java.util.*; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; import org.apache.spark.sql.types.StructType; /** - * This class is the in memory representation of rows as they are streamed through operators. It - * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that - * each operator allocates one of these objects, the storage footprint on the task is negligible. - * - * The layout is a columnar with values encoded in their native format. Each RowBatch contains - * a horizontal partitioning of the data, split into columns. - * - * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API. - * - * TODO: - * - There are many TODOs for the existing APIs. They should throw a not implemented exception. - * - Compaction: The batch and columns should be able to compact based on a selection vector. + * This class wraps multiple ColumnVectors as a row-wise table. 
It provides a row view of this + * batch so that Spark can access the data row by row. Instance of it is meant to be reused during + * the entire data loading process. */ public final class ColumnarBatch { public static final int DEFAULT_BATCH_SIZE = 4 * 1024; @@ -57,7 +49,7 @@ public void close() { } /** - * Returns an iterator over the rows in this batch. This skips rows that are filtered out. + * Returns an iterator over the rows in this batch. */ public Iterator rowIterator() { final int maxRows = numRows; @@ -87,19 +79,7 @@ public void remove() { } /** - * Resets the batch for writing. - */ - public void reset() { - for (int i = 0; i < numCols(); ++i) { - if (columns[i] instanceof WritableColumnVector) { - ((WritableColumnVector) columns[i]).reset(); - } - } - this.numRows = 0; - } - - /** - * Sets the number of rows that are valid. + * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { assert(numRows <= this.capacity); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 8bb33ed5b78c..3c6656dec77c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -24,8 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Row abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Row abstraction in {@link ColumnVector}. */ public final class ColumnarRow extends InternalRow { // The data for this row. @@ -34,7 +33,7 @@ public final class ColumnarRow extends InternalRow { private final int rowId; private final int numFields; - ColumnarRow(ColumnVector data, int rowId) { + public ColumnarRow(ColumnVector data, int rowId) { assert (data.dataType() instanceof StructType); this.data = data; this.rowId = rowId; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 782cec5e292b..5617046e1396 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** * Helper trait for abstracting scan functionality using - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es. + * [[ColumnarBatch]]es. 
*/ private[sql] trait ColumnarBatchScan extends CodegenSupport { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9a6f1c6dfa6a..ce3c68810f3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.{ColumnarRow, MutableColumnarRow} +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0380ee8b09d6..0cf9b53ce1d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector} +import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch /** * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache' diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcfc41243026..bcd1aa0890ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -32,8 +32,8 @@ import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3e73393b1285..933b9753faa6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition import 
org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} case class InMemoryTableScanExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 8731ee88f87f..835ce9846247 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5cc8ed353565..dc5ba96e69ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -30,8 +30,8 @@ import org.apache.spark._ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 508c116aae92..c42bc60a59d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.arrow import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.execution.vectorized.ArrowColumnVector import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 03490ad15a65..7304803a092c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -23,6 +23,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class 
ArrowColumnVectorSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 54b31cee031f..944240f3bade 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.columnar.ColumnAccessor import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.unsafe.types.UTF8String class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 7848ebdcab6d..675f06b31b97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval @@ -918,10 +919,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(it.hasNext == false) // Reset and add 3 rows - batch.reset() - assert(batch.numRows() == 0) - assert(batch.rowIterator().hasNext == false) - + columns.foreach(_.reset()) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] columns(0).putNull(0) columns(1).putDouble(0, 2.2) From 7d045c5f00e2c7c67011830e2169a4e130c3ace8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 13:14:52 +0800 Subject: [PATCH 074/100] [SPARK-22944][SQL] improve FoldablePropagation ## What changes were proposed in this pull request? `FoldablePropagation` is a little tricky as it needs to handle attributes that are miss-derived from children, e.g. outer join outputs. This rule does a kind of stop-able tree transform, to skip to apply this rule when hit a node which may have miss-derived attributes. Logically we should be able to apply this rule above the unsupported nodes, by just treating the unsupported nodes as leaf nodes. This PR improves this rule to not stop the tree transformation, but reduce the foldable expressions that we want to propagate. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20139 from cloud-fan/foldable. 
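To make the behavioral change concrete, here is a minimal DataFrame-level sketch of the kind of query that now benefits. It is illustrative only and not part of this patch; the object name and sample data are invented, and it mirrors the new `Propagate above outer join` test added to `FoldablePropagationSuite` below.

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

// Illustrative sketch only -- not part of this patch.
object FoldablePropagationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("foldable-sketch").master("local[*]").getOrCreate()
    import spark.implicits._

    // "b" and "d" alias foldable literals defined below the join.
    val left = Seq(1, 2, 3).toDF("a").withColumn("b", lit(1))
    val right = Seq(2, 3, 4).toDF("c").withColumn("d", lit(1))

    // With a left outer join, attributes coming from the right side can be null above the
    // join and are excluded from propagation, but the preserved left-side alias "b" can
    // still be folded into the projection above the join instead of stopping the rule.
    val joined = left.join(right, $"a" === $"c" && $"b" === $"d", "left_outer")
    joined.select(($"b" + 3).as("res")).explain(true)

    spark.stop()
  }
}
```

With this change, the optimized plan from `explain(true)` should project the folded literal rather than the alias, where previously the rule gave up as soon as it reached the outer join.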
--- .../sql/catalyst/optimizer/expressions.scala | 65 +++++++++++-------- .../optimizer/FoldablePropagationSuite.scala | 23 ++++++- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7d830bbb7dc3..1c0b7bd80680 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -506,18 +506,21 @@ object NullPropagation extends Rule[LogicalPlan] { /** - * Propagate foldable expressions: * Replace attributes with aliases of the original foldable expressions if possible. - * Other optimizations will take advantage of the propagated foldable expressions. - * + * Other optimizations will take advantage of the propagated foldable expressions. For example, + * this rule can optimize * {{{ * SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3 - * ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() * }}} + * to + * {{{ + * SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() + * }}} + * and other rules can further optimize it and remove the ORDER BY operator. */ object FoldablePropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val foldableMap = AttributeMap(plan.flatMap { + var foldableMap = AttributeMap(plan.flatMap { case Project(projectList, _) => projectList.collect { case a: Alias if a.child.foldable => (a.toAttribute, a) } @@ -530,38 +533,44 @@ object FoldablePropagation extends Rule[LogicalPlan] { if (foldableMap.isEmpty) { plan } else { - var stop = false CleanupAliases(plan.transformUp { - // A leaf node should not stop the folding process (note that we are traversing up the - // tree, starting at the leaf nodes); so we are allowing it. - case l: LeafNode => - l - // We can only propagate foldables for a subset of unary nodes. - case u: UnaryNode if !stop && canPropagateFoldables(u) => + case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) => u.transformExpressions(replaceFoldable) - // Allow inner joins. We do not allow outer join, although its output attributes are - // derived from its children, they are actually different attributes: the output of outer - // join is not always picked from its children, but can also be null. + // Join derives the output attributes from its child while they are actually not the + // same attributes. For example, the output of outer join is not always picked from its + // children, but can also be null. We should exclude these miss-derived attributes when + // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, Inner, _) if !stop => - j.transformExpressions(replaceFoldable) - - // We can fold the projections an expand holds. However expand changes the output columns - // and often reuses the underlying attributes; so we cannot assume that a column is still - // foldable after the expand has been applied. - // TODO(hvanhovell): Expand should use new attributes as the output attributes. 
- case expand: Expand if !stop => - val newExpand = expand.copy(projections = expand.projections.map { projection => + case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => + val newJoin = j.transformExpressions(replaceFoldable) + val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { + case _: InnerLike | LeftExistence(_) => Nil + case LeftOuter => right.output + case RightOuter => left.output + case FullOuter => left.output ++ right.output + }) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => missDerivedAttrsSet.contains(attr) + }.toSeq) + newJoin + + // We can not replace the attributes in `Expand.output`. If there are other non-leaf + // operators that have the `output` field, we should put them here too. + case expand: Expand if foldableMap.nonEmpty => + expand.copy(projections = expand.projections.map { projection => projection.map(_.transform(replaceFoldable)) }) - stop = true - newExpand - case other => - stop = true + // For other plans, they are not safe to apply foldable propagation, and they should not + // propagate foldable expressions from children. + case other if foldableMap.nonEmpty => + val childrenOutputSet = AttributeSet(other.children.flatMap(_.output)) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => childrenOutputSet.contains(attr) + }.toSeq) other }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index dccb32f0379a..c28844642aed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -147,8 +147,8 @@ class FoldablePropagationSuite extends PlanTest { test("Propagate in expand") { val c1 = Literal(1).as('a) val c2 = Literal(2).as('b) - val a1 = c1.toAttribute.withNullability(true) - val a2 = c2.toAttribute.withNullability(true) + val a1 = c1.toAttribute.newInstance().withNullability(true) + val a2 = c2.toAttribute.newInstance().withNullability(true) val expand = Expand( Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))), Seq(a1, a2), @@ -161,4 +161,23 @@ class FoldablePropagationSuite extends PlanTest { val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze comparePlans(optimized, correctAnswer) } + + test("Propagate above outer join") { + val left = LocalRelation('a.int).select('a, Literal(1).as('b)) + val right = LocalRelation('c.int).select('c, Literal(1).as('d)) + + val join = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && 'b === 'd)) + val query = join.select(('b + 3).as('res)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && Literal(1) === Literal(1))) + .select((Literal(1) + 3).as('res)).analyze + comparePlans(optimized, correctAnswer) + } } From df95a908baf78800556636a76d58bba9b3dd943f Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 3 Jan 2018 21:43:14 -0800 Subject: [PATCH 075/100] [SPARK-22933][SPARKR] R Structured Streaming API for withWatermark, trigger, partitionBy ## What changes were proposed in this pull request? R Structured Streaming API for withWatermark, trigger, partitionBy ## How was this patch tested? 
manual, unit tests Author: Felix Cheung Closes #20129 from felixcheung/rwater. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 96 +++++++++++++++- R/pkg/R/SQLContext.R | 4 +- R/pkg/R/generics.R | 6 + R/pkg/tests/fulltests/test_streaming.R | 107 ++++++++++++++++++ python/pyspark/sql/streaming.py | 4 + .../sql/execution/streaming/Triggers.scala | 2 +- 7 files changed, 214 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 3219c6f0cc47..c51eb0f39c4b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -179,6 +179,7 @@ exportMethods("arrange", "with", "withColumn", "withColumnRenamed", + "withWatermark", "write.df", "write.jdbc", "write.json", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fe238f6dd4eb..9956f7eda91e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3661,7 +3661,8 @@ setMethod("getNumPartitions", #' isStreaming #' #' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data -#' as it arrives. +#' as it arrives. A dataset that reads data from a streaming source must be executed as a +#' \code{StreamingQuery} using \code{write.stream}. #' #' @param x A SparkDataFrame #' @return TRUE if this SparkDataFrame is from a streaming source @@ -3707,7 +3708,17 @@ setMethod("isStreaming", #' @param df a streaming SparkDataFrame. #' @param source a name for external data source. #' @param outputMode one of 'append', 'complete', 'update'. -#' @param ... additional argument(s) passed to the method. +#' @param partitionBy a name or a list of names of columns to partition the output by on the file +#' system. If specified, the output is laid out on the file system similar to Hive's +#' partitioning scheme. +#' @param trigger.processingTime a processing time interval as a string, e.g. '5 seconds', +#' '1 minute'. This is a trigger that runs a query periodically based on the processing +#' time. If value is '0 seconds', the query will run as fast as possible, this is the +#' default. Only one trigger can be set. +#' @param trigger.once a logical, must be set to \code{TRUE}. This is a trigger that processes only +#' one batch of data in a streaming query then terminates the query. Only one trigger can be +#' set. +#' @param ... additional external data source specific named options. #' #' @family SparkDataFrame functions #' @seealso \link{read.stream} @@ -3725,7 +3736,8 @@ setMethod("isStreaming", #' # console #' q <- write.stream(wordCounts, "console", outputMode = "complete") #' # text stream -#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp" +#' partitionBy = c("year", "month"), trigger.processingTime = "30 seconds") #' # memory stream #' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") #' head(sql("SELECT * from outs")) @@ -3737,7 +3749,8 @@ setMethod("isStreaming", #' @note experimental setMethod("write.stream", signature(df = "SparkDataFrame"), - function(df, source = NULL, outputMode = NULL, ...) { + function(df, source = NULL, outputMode = NULL, partitionBy = NULL, + trigger.processingTime = NULL, trigger.once = NULL, ...) { if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. 
It is the data source specified ", "in 'spark.sql.sources.default' configuration by default.") @@ -3748,12 +3761,43 @@ setMethod("write.stream", if (is.null(source)) { source <- getDefaultSqlSource() } + cols <- NULL + if (!is.null(partitionBy)) { + if (!all(sapply(partitionBy, function(c) { is.character(c) }))) { + stop("All partitionBy column names should be characters.") + } + cols <- as.list(partitionBy) + } + jtrigger <- NULL + if (!is.null(trigger.processingTime) && !is.na(trigger.processingTime)) { + if (!is.null(trigger.once)) { + stop("Multiple triggers not allowed.") + } + interval <- as.character(trigger.processingTime) + if (nchar(interval) == 0) { + stop("Value for trigger.processingTime must be a non-empty string.") + } + jtrigger <- handledCallJStatic("org.apache.spark.sql.streaming.Trigger", + "ProcessingTime", + interval) + } else if (!is.null(trigger.once) && !is.na(trigger.once)) { + if (!is.logical(trigger.once) || !trigger.once) { + stop("Value for trigger.once must be TRUE.") + } + jtrigger <- callJStatic("org.apache.spark.sql.streaming.Trigger", "Once") + } options <- varargsToStrEnv(...) write <- handledCallJMethod(df@sdf, "writeStream") write <- callJMethod(write, "format", source) if (!is.null(outputMode)) { write <- callJMethod(write, "outputMode", outputMode) } + if (!is.null(cols)) { + write <- callJMethod(write, "partitionBy", cols) + } + if (!is.null(jtrigger)) { + write <- callJMethod(write, "trigger", jtrigger) + } write <- callJMethod(write, "options", options) ssq <- handledCallJMethod(write, "start") streamingQuery(ssq) @@ -3967,3 +4011,47 @@ setMethod("broadcast", sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) dataFrame(sdf) }) + +#' withWatermark +#' +#' Defines an event time watermark for this streaming SparkDataFrame. A watermark tracks a point in +#' time before which we assume no more late data is going to arrive. +#' +#' Spark will use this watermark for several purposes: +#' \itemize{ +#' \item{-} To know when a given time window aggregation can be finalized and thus can be emitted +#' when using output modes that do not allow updates. +#' \item{-} To minimize the amount of state that we need to keep for on-going aggregations. +#' } +#' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across +#' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost +#' of coordinating this value across partitions, the actual watermark used is only guaranteed +#' to be at least \code{delayThreshold} behind the actual event time. In some cases we may still +#' process records that arrive more than \code{delayThreshold} late. +#' +#' @param x a streaming SparkDataFrame +#' @param eventTime a string specifying the name of the Column that contains the event time of the +#' row. +#' @param delayThreshold a string specifying the minimum delay to wait to data to arrive late, +#' relative to the latest record that has been processed in the form of an +#' interval (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative. +#' @return a SparkDataFrame. 
+#' @aliases withWatermark,SparkDataFrame,character,character-method +#' @family SparkDataFrame functions +#' @rdname withWatermark +#' @name withWatermark +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' schema <- structType(structField("time", "timestamp"), structField("value", "double")) +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' df <- withWatermark(df, "time", "10 minutes") +#' } +#' @note withWatermark since 2.3.0 +setMethod("withWatermark", + signature(x = "SparkDataFrame", eventTime = "character", delayThreshold = "character"), + function(x, eventTime, delayThreshold) { + sdf <- callJMethod(x@sdf, "withWatermark", eventTime, delayThreshold) + dataFrame(sdf) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 3b7f71bbbffb..9d0a2d5e074e 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -727,7 +727,9 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it +#' uses the default value, session local timezone. #' @return SparkDataFrame #' @rdname read.stream #' @name read.stream diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5369c32544e5..e0dde3339fab 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -799,6 +799,12 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) +#' @rdname withWatermark +#' @export +setGeneric("withWatermark", function(x, eventTime, delayThreshold) { + standardGeneric("withWatermark") +}) + #' @rdname write.df #' @export setGeneric("write.df", function(df, path = NULL, ...) 
{ standardGeneric("write.df") }) diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index 54f40bbd5f51..a354d50c6b54 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -172,6 +172,113 @@ test_that("Terminated by error", { stopQuery(q) }) +test_that("PartitionBy", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + checkpointPath <- tempfile(pattern = "sparkr-test", fileext = ".checkpoint") + textPath <- tempfile(pattern = "sparkr-test", fileext = ".text") + df <- read.df(jsonPath, "json", stringSchema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + + expect_error(write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = c(1, 2)), + "All partitionBy column names should be characters") + + q <- write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = "name") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + dirs <- list.files(textPath) + expect_equal(length(dirs[substring(dirs, 1, nchar("name=")) == "name="]), 3) + + unlink(checkpointPath) + unlink(textPath) + unlink(parquetPath) +}) + +test_that("Watermark", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + t <- Sys.time() + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + df <- withColumn(df, "eventTime", cast(df$value, "timestamp")) + df <- withWatermark(df, "eventTime", "10 seconds") + counts <- count(group_by(df, "eventTime")) + q <- write.stream(counts, "memory", queryName = "times", outputMode = "append") + + # first events + df <- as.DataFrame(lapply(list(t + 1, t, t + 2), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # advance watermark to 15 + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # old events, should be dropped + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # evict events less than previous watermark + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + times <- collect(sql("SELECT * FROM times")) + # looks like write timing can affect the first bucket; but it should be t + expect_equal(times[order(times$eventTime),][1, 2], 2) + + stopQuery(q) + unlink(parquetPath) +}) + +test_that("Trigger", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "", trigger.once = ""), "Multiple triggers not allowed.") + + expect_error(write.stream(df, "memory", 
queryName = "times", outputMode = "append", + trigger.processingTime = ""), + "Value for trigger.processingTime must be a non-empty string.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "invalid"), "illegal argument") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = ""), "Value for trigger.once must be TRUE.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = FALSE), "Value for trigger.once must be TRUE.") + + q <- write.stream(df, "memory", queryName = "times", outputMode = "append", trigger.once = TRUE) + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + expect_equal(nrow(collect(sql("SELECT * FROM times"))), 1) + + stopQuery(q) + unlink(parquetPath) +}) + unlink(jsonPath) unlink(jsonPathNa) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fb228f99ba7a..24ae3776a217 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -793,6 +793,10 @@ def trigger(self, processingTime=None, once=None): .. note:: Evolving. :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + Set a trigger that runs a query periodically based on the processing + time. Only one trigger can be set. + :param once: if set to True, set a trigger that processes only one batch of data in a + streaming query then terminates the query. Only one trigger can be set. >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 271bc4da99c0..19e3e55cb282 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.streaming.Trigger /** - * A [[Trigger]] that process only one batch of data in a streaming query then terminates + * A [[Trigger]] that processes only one batch of data in a streaming query then terminates * the query. */ @Experimental From 9fa703e89318922393bae03c0db4575f4f4b4c56 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 4 Jan 2018 19:10:10 +0800 Subject: [PATCH 076/100] [SPARK-22950][SQL] Handle ChildFirstURLClassLoader's parent ## What changes were proposed in this pull request? ChildFirstClassLoader's parent is set to null, so we can't get jars from its parent. This will cause ClassNotFoundException during HiveClient initialization with builtin hive jars, where we may should use spark context loader instead. ## How was this patch tested? add new ut cc cloud-fan gatorsmile Author: Kent Yao Closes #20145 from yaooqinn/SPARK-22950. 
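The failure mode is easy to reproduce with a plain JDK class loader that, like `ChildFirstURLClassLoader`, reports `null` as its parent. The sketch below is illustrative only (names invented) and shows why a naive parent walk finds no jars; the diff below special-cases `ChildFirstURLClassLoader` and falls back to the Spark class loader instead of stopping.

```
import java.net.{URL, URLClassLoader}

// Illustrative sketch only -- not part of this patch.
object ParentWalkSketch {
  // Collect jar URLs by walking up the class-loader chain, roughly what
  // HiveUtils.allJars did before this change.
  def allJarsNaive(cl: ClassLoader): Array[URL] = cl match {
    case null => Array.empty[URL]
    case u: URLClassLoader => u.getURLs ++ allJarsNaive(u.getParent)
    case other => allJarsNaive(other.getParent)
  }

  def main(args: Array[String]): Unit = {
    // A child-first loader keeps the real parent in its own field and passes null to the
    // JDK, so the walk stops immediately and only this loader's own URLs are visible.
    val childFirstLike = new URLClassLoader(Array.empty[URL], null)
    println(allJarsNaive(childFirstLike).length) // prints 0
  }
}
```

After the fix, hitting a `ChildFirstURLClassLoader` no longer terminates the walk: its own URLs are combined with the jars reachable from the Spark class loader.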
--- .../org/apache/spark/sql/hive/HiveUtils.scala | 4 +++- .../spark/sql/hive/HiveUtilsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c489690af8cd..c7717d70c996 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, WAREHOUSE_PATH} import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ChildFirstURLClassLoader, Utils} private[spark] object HiveUtils extends Logging { @@ -312,6 +312,8 @@ private[spark] object HiveUtils extends Logging { // starting from the given classLoader. def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { case null => Array.empty[URL] + case childFirst: ChildFirstURLClassLoader => + childFirst.getURLs() ++ allJars(Utils.getSparkClassLoader) case urlClassLoader: URLClassLoader => urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) case other => allJars(other.getParent) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index fdbfcf1a6844..8697d47e89e8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.hive +import java.net.URL + import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -42,4 +47,19 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton assert(hiveConf("foo") === "bar") } } + + test("ChildFirstURLClassLoader's parent is null, get spark classloader instead") { + val conf = new SparkConf + val contextClassLoader = Thread.currentThread().getContextClassLoader + val loader = new ChildFirstURLClassLoader(Array(), contextClassLoader) + try { + Thread.currentThread().setContextClassLoader(loader) + HiveUtils.newClientForMetadata( + conf, + SparkHadoopUtil.newConfiguration(conf), + HiveUtils.newTemporaryConfiguration(useInMemoryDerby = true)) + } finally { + Thread.currentThread().setContextClassLoader(contextClassLoader) + } + } } From d5861aba9d80ca15ad3f22793b79822e470d6913 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 19:17:22 +0800 Subject: [PATCH 077/100] [SPARK-22945][SQL] add java UDF APIs in the functions object ## What changes were proposed in this pull request? Currently Scala users can use UDF like ``` val foo = udf((i: Int) => Math.random() + i).asNondeterministic df.select(foo('a)) ``` Python users can also do it with similar APIs. However Java users can't do it, we should add Java UDF APIs in the functions object. ## How was this patch tested? new tests Author: Wenchen Fan Closes #20141 from cloud-fan/udf. 
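For reference, the new overloads can also be exercised from Scala. The sketch below is illustrative only (object name and sample data invented) and passes a `UDF2` instance plus an explicit return type, matching the new `JavaDataFrameSuite` and `UDFSuite` tests in the diff below.

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.DataTypes

// Illustrative sketch only -- not part of this patch.
object JavaUdfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("java-udf-sketch").master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("key", "value")

    // The caller supplies the output data type explicitly; there is no automatic input
    // type coercion, and the returned UDF is deterministic unless asNondeterministic()
    // is called on it.
    val concatKeyValue = udf(new UDF2[Integer, String, String] {
      override def call(i: Integer, s: String): String = i.toString + s
    }, DataTypes.StringType)

    df.select(concatKeyValue(col("key"), col("value"))).show()
    spark.stop()
  }
}
```

Java callers get the same thing with a lambda, e.g. `udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType)`, as the new `JavaDataFrameSuite#testUDF` in this patch shows.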
--- .../apache/spark/sql/UDFRegistration.scala | 90 ++--- .../sql/expressions/UserDefinedFunction.scala | 1 + .../org/apache/spark/sql/functions.scala | 313 ++++++++++++++---- .../apache/spark/sql/JavaDataFrameSuite.java | 11 + .../scala/org/apache/spark/sql/UDFSuite.scala | 12 +- 5 files changed, 315 insertions(+), 112 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index dc2468a721e4..f94baef39dfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.lang.reflect.{ParameterizedType, Type} +import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -110,29 +110,29 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends /* register 0-22 were generated by this script - (0 to 22).map { x => + (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) println(s""" - /** - * Registers a deterministic Scala closure of ${x} arguments as user-defined function (UDF). - * @tparam RT return type of UDF. - * @since 1.3.0 - */ - def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - def builder(e: Seq[Expression]) = if (e.length == $x) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) - } else { - throw new AnalysisException("Invalid number of arguments for function " + name + - ". Expected: $x; Found: " + e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). + | * @tparam RT return type of UDF. + | * @since 1.3.0 + | */ + |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | def builder(e: Seq[Expression]) = if (e.length == $x) { + | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + | } else { + | throw new AnalysisException("Invalid number of arguments for function " + name + + | ". Expected: $x; Found: " + e.length) + | } + | functionRegistry.createOrReplaceTempFunction(name, builder) + | val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) } (0 to 22).foreach { i => @@ -144,7 +144,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val funcCall = if (i == 0) "() => func" else "func" println(s""" |/** - | * Register a user-defined function with ${i} arguments. + | * Register a deterministic Java UDF$i instance as user-defined function (UDF). 
| * @since $version | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { @@ -689,7 +689,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 0 arguments. + * Register a deterministic Java UDF0 instance as user-defined function (UDF). * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { @@ -704,7 +704,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 1 arguments. + * Register a deterministic Java UDF1 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { @@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 2 arguments. + * Register a deterministic Java UDF2 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { @@ -734,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 3 arguments. + * Register a deterministic Java UDF3 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { @@ -749,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 4 arguments. + * Register a deterministic Java UDF4 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { @@ -764,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 5 arguments. + * Register a deterministic Java UDF5 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { @@ -779,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 6 arguments. + * Register a deterministic Java UDF6 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -794,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 7 arguments. + * Register a deterministic Java UDF7 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -809,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 8 arguments. + * Register a deterministic Java UDF8 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -824,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 9 arguments. + * Register a deterministic Java UDF9 instance as user-defined function (UDF). 
* @since 1.3.0 */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -839,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 10 arguments. + * Register a deterministic Java UDF10 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -854,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 11 arguments. + * Register a deterministic Java UDF11 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -869,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 12 arguments. + * Register a deterministic Java UDF12 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -884,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 13 arguments. + * Register a deterministic Java UDF13 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -899,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 14 arguments. + * Register a deterministic Java UDF14 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -914,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 15 arguments. + * Register a deterministic Java UDF15 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -929,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 16 arguments. + * Register a deterministic Java UDF16 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -944,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 17 arguments. + * Register a deterministic Java UDF17 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -959,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 18 arguments. + * Register a deterministic Java UDF18 instance as user-defined function (UDF). 
* @since 1.3.0 */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -974,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 19 arguments. + * Register a deterministic Java UDF19 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -989,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 20 arguments. + * Register a deterministic Java UDF20 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1004,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 21 arguments. + * Register a deterministic Java UDF21 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1019,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 22 arguments. + * Register a deterministic Java UDF22 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 03b654f83052..40a058d2cadd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -66,6 +66,7 @@ case class UserDefinedFunction protected[sql] ( * * @since 1.3.0 */ + @scala.annotation.varargs def apply(exprs: Column*): Column = { Column(ScalaUDF( f, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 530a525a01de..0d11682d80a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,6 +24,7 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -3254,42 +3254,66 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } - 
////////////////////////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////////////////// - // scalastyle:off line.size.limit // scalastyle:off parameter.number /* Use the following code to generate: - (0 to 10).map { x => + + (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" - /** - * Defines a deterministic user-defined function of ${x} arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. - * - * @group udf_funcs - * @since 1.3.0 - */ - def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Defines a Scala closure of $x arguments as user-defined function (UDF). + | * The data types are automatically inferred based on the Scala closure's + | * signature. By default the returned UDF is deterministic. To change it to + | * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 1.3.0 + | */ + |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | val udf = UserDefinedFunction(f, dataType, inputTypes) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) + } + + (0 to 10).foreach { i => + val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + val funcCall = if (i == 0) "() => func" else "func" + println(s""" + |/** + | * Defines a Java UDF$i instance as user-defined function (UDF). + | * The caller must specify the output data type, and there is no automatic input type coercion. + | * By default the returned UDF is deterministic. To change it to nondeterministic, call the + | * API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 2.3.0 + | */ + |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { + | val func = f$anyCast.call($anyParams) + | UserDefinedFunction($funcCall, returnType, inputTypes = None) + |}""".stripMargin) } */ + ////////////////////////////////////////////////////////////////////////////////////////////// + // Scala UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** - * Defines a deterministic user-defined function of 0 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 0 arguments as user-defined function (UDF). 
+ * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3302,10 +3326,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 1 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3318,10 +3342,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 2 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3334,10 +3358,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 3 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3350,10 +3374,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 4 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3366,10 +3390,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 5 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. 
To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3382,10 +3406,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 6 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 6 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3398,10 +3422,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 7 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 7 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3414,10 +3438,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 8 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 8 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3430,10 +3454,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 9 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 9 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3446,10 +3470,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 10 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 10 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 @@ -3461,13 +3485,172 @@ object functions { if (nullable) udf else udf.asNonNullable() } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF0 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF0[Any]].call() + UserDefinedFunction(() => func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF1 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF2 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF3 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF4 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF5 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. 
+ * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF6 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF7 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF8 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF9 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF10 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. 
+ * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + // scalastyle:on parameter.number // scalastyle:on line.size.limit /** * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, * the caller must specify the output data type, and there is no automatic input type coercion. - * To change a UDF to nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. * * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index b007093dad84..4f8a31f18572 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.expressions.UserDefinedFunction; import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; import org.apache.spark.util.sketch.BloomFilter; @@ -455,4 +456,14 @@ public void testCircularReferenceBean() { CircularReference1Bean bean = new CircularReference1Bean(); spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class); } + + @Test + public void testUDF() { + UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType); + Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)).toArray(String[]::new); + String[] expected = spark.table("testData").collectAsList().stream() + .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); + Assert.assertArrayEquals(expected, result); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7f1c009ca6e7..db37be68e42e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.types.DataTypes +import org.apache.spark.sql.types.{DataTypes, DoubleType} private case class FunctionResult(f1: String, f2: String) @@ -128,6 +129,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df2 = testData.select(bar()) assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) assert(df2.head().getDouble(0) >= 0.0) + + val javaUdf = udf(new UDF0[Double] { + override def call(): Double = 
Math.random() + }, DoubleType).asNondeterministic() + val df3 = testData.select(javaUdf()) + assert(df3.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df3.head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { From 5aadbc929cb194e06dbd3bab054a161569289af5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 4 Jan 2018 21:07:31 +0800 Subject: [PATCH 078/100] [SPARK-22939][PYSPARK] Support Spark UDF in registerFunction ## What changes were proposed in this pull request? ```Python import random from pyspark.sql.functions import udf from pyspark.sql.types import IntegerType, StringType random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() spark.catalog.registerFunction("random_udf", random_udf, StringType()) spark.sql("SELECT random_udf()").collect() ``` We will get the following error. ``` Py4JError: An error occurred while calling o29.__getnewargs__. Trace: py4j.Py4JException: Method __getnewargs__([]) does not exist at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318) at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326) at py4j.Gateway.invoke(Gateway.java:274) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:214) at java.lang.Thread.run(Thread.java:745) ``` This PR is to support it. ## How was this patch tested? WIP Author: gatorsmile Closes #20137 from gatorsmile/registerFunction. --- python/pyspark/sql/catalog.py | 27 +++++++++++++++---- python/pyspark/sql/context.py | 16 +++++++++--- python/pyspark/sql/tests.py | 49 +++++++++++++++++++++++++---------- python/pyspark/sql/udf.py | 21 ++++++++++----- 4 files changed, 84 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 659bc65701a0..156603128d06 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName): @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. 
:param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) + + # This is to check whether the input function is a wrapped/native UserDefinedFunction + if hasattr(f, 'asNondeterministic'): + udf = UserDefinedFunction(f.func, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=f.deterministic) + else: + udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) self._jsparkSession.udf().registerPython(name, udf._judf) return udf._wrapped() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b1e723cdecef..b8d86cc098e9 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. 
:param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType()) + >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ return self.sparkSession.catalog.registerFunction(name, f, returnType) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 67bdb3d72d93..6dc767f9ec46 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -378,6 +378,41 @@ def test_udf2(self): [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_udf3(self): + twoargs = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) + self.assertEqual(twoargs.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_nondeterministic_udf(self): + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + self.assertEqual(udf_random_col.deterministic, False) + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + + def test_nondeterministic_udf2(self): + import random + from pyspark.sql.functions import udf + random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() + self.assertEqual(random_udf.deterministic, False) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) + self.assertEqual(random_udf1.deterministic, False) + [row] = self.spark.sql("SELECT randInt()").collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf1()).collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf()).collect() + self.assertEqual(row[0], 6) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) + pydoc.render_doc(random_udf) + pydoc.render_doc(random_udf1) + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -435,15 +470,6 @@ def test_udf_with_array_type(self): self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) - def test_nondeterministic_udf(self): - from pyspark.sql.functions import udf - import random - udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() - df = 
self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) - udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) - [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() - self.assertEqual(row[0] + 10, row[1]) - def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) @@ -567,7 +593,6 @@ def test_read_multiple_orc_file(self): def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType sourceFile = udf(lambda path: path, StringType()) filePath = "python/test_support/sql/people1.json" row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() @@ -575,7 +600,6 @@ def test_udf_with_input_file_name(self): def test_udf_with_input_file_name_for_hadooprdd(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType def filename(path): return path @@ -635,7 +659,6 @@ def test_udf_with_string_return_type(self): def test_udf_shouldnt_accept_noncallable_object(self): from pyspark.sql.functions import UserDefinedFunction - from pyspark.sql.types import StringType non_callable = None self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) @@ -1299,7 +1322,6 @@ def test_between_function(self): df.filter(df.a.between(df.b, df.c)).collect()) def test_struct_type(self): - from pyspark.sql.types import StructType, StringType, StructField struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) @@ -1368,7 +1390,6 @@ def test_parse_datatype_string(self): _parse_datatype_string("a INT, c DOUBLE")) def test_metadata_null(self): - from pyspark.sql.types import StructType, StringType, StructField schema = StructType([StructField("f1", StringType(), True, None), StructField("f2", StringType(), True, {'a': None})]) rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 54b5a8656e1c..5e75eb654533 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType): ) # Set the name of the UserDefinedFunction object to be the name of function f - udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + udf_obj = UserDefinedFunction( + f, returnType=returnType, name=None, evalType=evalType, deterministic=True) return udf_obj._wrapped() @@ -67,8 +68,10 @@ class UserDefinedFunction(object): .. 
versionadded:: 1.3 """ def __init__(self, func, - returnType=StringType(), name=None, - evalType=PythonEvalType.SQL_BATCHED_UDF): + returnType=StringType(), + name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=True): if not callable(func): raise TypeError( "Invalid function: not a function or callable (__call__ is not defined): " @@ -92,7 +95,7 @@ def __init__(self, func, func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) self.evalType = evalType - self._deterministic = True + self.deterministic = deterministic @property def returnType(self): @@ -130,7 +133,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.evalType, self._deterministic) + self._name, wrapped_func, jdt, self.evalType, self.deterministic) return judf def __call__(self, *cols): @@ -138,6 +141,9 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + # This function is for improving the online help system in the interactive interpreter. + # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and + # argument annotation. (See: SPARK-19161) def _wrapped(self): """ Wrap this udf with a function and attach docstring from func @@ -162,7 +168,8 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType wrapper.evalType = self.evalType - wrapper.asNondeterministic = self.asNondeterministic + wrapper.deterministic = self.deterministic + wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() return wrapper @@ -172,5 +179,5 @@ def asNondeterministic(self): .. versionadded:: 2.3 """ - self._deterministic = False + self.deterministic = False return self From 6f68316e98fad72b171df422566e1fc9a7bbfcde Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 4 Jan 2018 21:15:10 +0800 Subject: [PATCH 079/100] [SPARK-22771][SQL] Add a missing return statement in Concat.checkInputDataTypes ## What changes were proposed in this pull request? This pr is a follow-up to fix a bug left in #19977. ## How was this patch tested? Added tests in `StringExpressionsSuite`. Author: Takeshi Yamamuro Closes #20149 from maropu/SPARK-22771-FOLLOWUP. 
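A note on the bug class before the diff: in Scala an `if` without an explicit `return` is only an expression, so when another expression follows it in the same block the failure value is computed and then silently discarded. A minimal, self-contained sketch of the pattern (illustrative names, not the actual Spark code):

```scala
// Sketch of the bug: the failure branch is evaluated but its value is discarded.
def checkBuggy(xs: Seq[Int]): String = {
  if (xs.exists(_ < 0)) {
    "failure: negative input"   // result of this `if` goes nowhere
  }
  "success"                     // always reached, even for bad input
}

// With an explicit `return`, the method short-circuits on the failure path.
def checkFixed(xs: Seq[Int]): String = {
  if (xs.exists(_ < 0)) {
    return "failure: negative input"
  }
  "success"
}
```

The `return` added to `Concat.checkInputDataTypes` below is this second form, and the new `StringExpressionsSuite` cases check that mismatched input types now surface as a type-check failure.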
--- .../sql/catalyst/expressions/stringExpressions.scala | 2 +- .../expressions/StringExpressionsSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0da55a4a961..41dc762154a4 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -58,7 +58,7 @@ case class Concat(children: Seq[Expression]) extends Expression { } else { val childTypes = children.map(_.dataType) if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - TypeCheckResult.TypeCheckFailure( + return TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should have StringType or BinaryType, but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 54cde77176e2..97ddbeba2c5c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -51,6 +51,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow) } + test("SPARK-22771 Check Concat.checkInputDataTypes results") { + assert(Concat(Seq.empty[Expression]).checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a") :: Literal.create("b") :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a".getBytes) :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create(1) :: Literal.create(2) :: Nil) + .checkInputDataTypes().isFailure) + assert(Concat(Literal.create("a") :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isFailure) + } + test("concat_ws") { def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { val inputExprs = inputs.map { From 93f92c0ed7442a4382e97254307309977ff676f8 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 4 Jan 2018 11:39:42 -0800 Subject: [PATCH 080/100] [SPARK-21475][CORE][2ND ATTEMPT] Change to use NIO's Files API for external shuffle service ## What changes were proposed in this pull request? This PR is the second attempt of #18684 , NIO's Files API doesn't override `skip` method for `InputStream`, so it will bring in performance issue (mentioned in #20119). But using `FileInputStream`/`FileOutputStream` will also bring in memory issue (https://dzone.com/articles/fileinputstream-fileoutputstream-considered-harmful), which is severe for long running external shuffle service. So here in this proposal, only fixing the external shuffle service related code. ## How was this patch tested? Existing tests. Author: jerryshao Closes #20144 from jerryshao/SPARK-21475-v2. 
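Context for the API swap in the hunks below: the memory issue referenced above comes from the finalizers that `FileInputStream`/`FileOutputStream` register per instance, which a long-running external shuffle service accumulates as GC work; the NIO factory methods avoid that, at the cost of an `InputStream` whose default `skip` is slower, which is why only the shuffle-service paths change. A rough sketch of the two open styles (illustrative helper, not the Spark code):

```scala
// Illustrative comparison of the two ways to open a file for reading.
import java.io.{File, FileInputStream, InputStream}
import java.nio.channels.FileChannel
import java.nio.file.{Files, StandardOpenOption}

def openOld(file: File): InputStream =
  new FileInputStream(file)                                  // registers a finalizer

def openNew(file: File): (InputStream, FileChannel) = {
  val in = Files.newInputStream(file.toPath)                 // no finalizer
  val ch = FileChannel.open(file.toPath, StandardOpenOption.READ)
  (in, ch)
}
```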
--- .../apache/spark/network/buffer/FileSegmentManagedBuffer.java | 3 ++- .../apache/spark/network/shuffle/ShuffleIndexInformation.java | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index c20fab83c346..8b8f9892847c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -24,6 +24,7 @@ import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; import com.google.common.base.Objects; import com.google.common.io.ByteStreams; @@ -132,7 +133,7 @@ public Object convertToNetty() throws IOException { if (conf.lazyFileDescriptor()) { return new DefaultFileRegion(file, offset, length); } else { - FileChannel fileChannel = new FileInputStream(file).getChannel(); + FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); return new DefaultFileRegion(fileChannel, offset, length); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index eacf485344b7..386738ece51a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -19,10 +19,10 @@ import java.io.DataInputStream; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.LongBuffer; +import java.nio.file.Files; /** * Keeps the index information for a particular map output @@ -39,7 +39,7 @@ public ShuffleIndexInformation(File indexFile) throws IOException { offsets = buffer.asLongBuffer(); DataInputStream dis = null; try { - dis = new DataInputStream(new FileInputStream(indexFile)); + dis = new DataInputStream(Files.newInputStream(indexFile.toPath())); dis.readFully(buffer.array()); } finally { if (dis != null) { From d2cddc88eac32f26b18ec26bb59e85c6f09a8c88 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:19:00 -0600 Subject: [PATCH 081/100] [SPARK-22850][CORE] Ensure queued events are delivered to all event queues. The code in LiveListenerBus was queueing events before start in the queues themselves; so in situations like the following: bus.post(someEvent) bus.addToEventLogQueue(listener) bus.start() "someEvent" would not be delivered to "listener" if that was the first listener in the queue, because the queue wouldn't exist when the event was posted. This change buffers the events before starting the bus in the bus itself, so that they can be delivered to all registered queues when the bus is started. Also tweaked the unit tests to cover the behavior above. Author: Marcelo Vanzin Closes #20039 from vanzin/SPARK-22850. 
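The fix below is a small buffer-and-replay state machine: until `start()` runs, posted events accumulate in one bus-level buffer; `start()` replays that buffer into every queue registered so far and then drops it, after which posts go straight to the queues. A stripped-down sketch of the idea (toy types, not the Spark classes, and without the real code's fast path that avoids locking after start):

```scala
// Minimal buffer-and-replay bus: queues added before start() still see early events.
import scala.collection.mutable.ListBuffer

class TinyBus[E] {
  private val queues = ListBuffer.empty[ListBuffer[E]]
  private var buffered: ListBuffer[E] = ListBuffer.empty[E]  // becomes null once started

  def addQueue(q: ListBuffer[E]): Unit = synchronized { queues += q }

  def post(event: E): Unit = synchronized {
    if (buffered != null) buffered += event     // not started yet: hold the event
    else queues.foreach(_ += event)             // started: deliver to every queue
  }

  def start(): Unit = synchronized {
    queues.foreach(q => buffered.foreach(e => q += e))  // replay to all queues
    buffered = null                                     // later posts go direct
  }
}
```

This mirrors the scenario described above: an event posted before a queue is registered and before `start()` still reaches that queue once the bus starts.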
--- .../spark/scheduler/LiveListenerBus.scala | 45 ++++++++++++++++--- .../spark/scheduler/SparkListenerSuite.scala | 21 +++++---- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 23121402b102..ba6387a8f08a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -62,6 +62,9 @@ private[spark] class LiveListenerBus(conf: SparkConf) { private val queues = new CopyOnWriteArrayList[AsyncEventQueue]() + // Visible for testing. + @volatile private[scheduler] var queuedEvents = new mutable.ListBuffer[SparkListenerEvent]() + /** Add a listener to queue shared by all non-internal listeners. */ def addToSharedQueue(listener: SparkListenerInterface): Unit = { addToQueue(listener, SHARED_QUEUE) @@ -125,13 +128,39 @@ private[spark] class LiveListenerBus(conf: SparkConf) { /** Post an event to all queues. */ def post(event: SparkListenerEvent): Unit = { - if (!stopped.get()) { - metrics.numEventsPosted.inc() - val it = queues.iterator() - while (it.hasNext()) { - it.next().post(event) + if (stopped.get()) { + return + } + + metrics.numEventsPosted.inc() + + // If the event buffer is null, it means the bus has been started and we can avoid + // synchronization and post events directly to the queues. This should be the most + // common case during the life of the bus. + if (queuedEvents == null) { + postToQueues(event) + return + } + + // Otherwise, need to synchronize to check whether the bus is started, to make sure the thread + // calling start() picks up the new event. + synchronized { + if (!started.get()) { + queuedEvents += event + return } } + + // If the bus was already started when the check above was made, just post directly to the + // queues. + postToQueues(event) + } + + private def postToQueues(event: SparkListenerEvent): Unit = { + val it = queues.iterator() + while (it.hasNext()) { + it.next().post(event) + } } /** @@ -149,7 +178,11 @@ private[spark] class LiveListenerBus(conf: SparkConf) { } this.sparkContext = sc - queues.asScala.foreach(_.start(sc)) + queues.asScala.foreach { q => + q.start(sc) + queuedEvents.foreach(q.post) + } + queuedEvents = null metricsSystem.registerSource(metrics) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 1beb36afa95f..da6ecb82c7e4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -48,7 +48,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match bus.metrics.metricRegistry.counter(s"queue.$SHARED_QUEUE.numDroppedEvents").getCount } - private def queueSize(bus: LiveListenerBus): Int = { + private def sharedQueueSize(bus: LiveListenerBus): Int = { bus.metrics.metricRegistry.getGauges().get(s"queue.$SHARED_QUEUE.size").getValue() .asInstanceOf[Int] } @@ -73,12 +73,11 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val conf = new SparkConf() val counter = new BasicJobCounter val bus = new LiveListenerBus(conf) - bus.addToSharedQueue(counter) // Metrics are initially empty. 
assert(bus.metrics.numEventsPosted.getCount === 0) assert(numDroppedEvents(bus) === 0) - assert(queueSize(bus) === 0) + assert(bus.queuedEvents.size === 0) assert(eventProcessingTimeCount(bus) === 0) // Post five events: @@ -87,7 +86,10 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Five messages should be marked as received and queued, but no messages should be posted to // listeners yet because the the listener bus hasn't been started. assert(bus.metrics.numEventsPosted.getCount === 5) - assert(queueSize(bus) === 5) + assert(bus.queuedEvents.size === 5) + + // Add the counter to the bus after messages have been queued for later delivery. + bus.addToSharedQueue(counter) assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -95,9 +97,12 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(eventProcessingTimeCount(bus) === 5) + // After the bus is started, there should be no more queued events. + assert(bus.queuedEvents === null) + // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } @@ -188,18 +193,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Post a message to the listener bus and wait for processing to begin: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(numDroppedEvents(bus) === 0) // If we post an additional message then it should remain in the queue because the listener is // busy processing the first event: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 0) // The queue is now full, so any additional events posted to the listener will be dropped: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 1) // Allow the the remaining events to be processed so we can stop the listener bus: From 95f9659abe8845f9f3f42fd7ababd79e55c52489 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 15:00:09 -0800 Subject: [PATCH 082/100] [SPARK-22948][K8S] Move SparkPodInitContainer to correct package. Author: Marcelo Vanzin Closes #20156 from vanzin/SPARK-22948. 
--- dev/sparktestsupport/modules.py | 2 +- .../spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala | 2 +- .../deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala | 2 +- .../docker/src/main/dockerfiles/init-container/Dockerfile | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala (99%) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala (98%) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index f834563da9dd..7164180a6a7b 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -539,7 +539,7 @@ def __hash__(self): kubernetes = Module( name="kubernetes", dependencies=[], - source_file_regexes=["resource-managers/kubernetes/core"], + source_file_regexes=["resource-managers/kubernetes"], build_profile_flags=["-Pkubernetes"], sbt_test_goals=["kubernetes/test"] ) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala similarity index 99% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala index 4a4b628aedbb..c0f08786b76a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.concurrent.TimeUnit diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala similarity index 98% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala index 6c557ec4a7c9..e0f29ecd0fb5 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.UUID diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 055493188fcb..047056ab2633 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -21,4 +21,4 @@ FROM spark-base # command should be invoked from the top level directory of the Spark distribution. E.g.: # docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . -ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.rest.k8s.SparkPodInitContainer" ] +ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.k8s.SparkPodInitContainer" ] From e288fc87a027ec1e1a21401d1f151df20dbfecf3 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 15:35:20 -0800 Subject: [PATCH 083/100] [SPARK-22953][K8S] Avoids adding duplicated secret volumes when init-container is used ## What changes were proposed in this pull request? User-specified secrets are mounted into both the main container and init-container (when it is used) in a Spark driver/executor pod, using the `MountSecretsBootstrap`. Because `MountSecretsBootstrap` always adds new secret volumes for the secrets to the pod, the same secret volumes get added twice, one when mounting the secrets to the main container, and the other when mounting the secrets to the init-container. This PR fixes the issue by separating `MountSecretsBootstrap.mountSecrets` out into two methods: `addSecretVolumes` for adding secret volumes to a pod and `mountSecrets` for mounting secret volumes to a container, respectively. `addSecretVolumes` is only called once for each pod, whereas `mountSecrets` is called individually for the main container and the init-container (if it is used). Ref: https://github.com/apache-spark-on-k8s/spark/issues/594. ## How was this patch tested? Unit tested and manually tested. vanzin This replaces https://github.com/apache/spark/pull/20148. hex108 foxish kimoonkim Author: Yinan Li Closes #20159 from liyinan926/master. 
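The refactoring is easiest to see at the call sites: declaring a secret volume is a pod-level operation that must happen exactly once, while mounting it is a container-level operation that happens once per container (the main container and, when used, the init-container). A self-contained toy sketch of that split (plain stand-in types, not the fabric8 model or the Spark bootstrap class):

```scala
// Toy model: add volumes once per pod, mount them once per container.
case class Pod(volumes: List[String] = Nil)
case class Container(mounts: List[String] = Nil)

class SecretsBootstrap(secretsToPaths: Map[String, String]) {
  // Pod-level step: one volume per secret; calling it once avoids duplicates.
  def addSecretVolumes(pod: Pod): Pod =
    pod.copy(volumes = pod.volumes ++ secretsToPaths.keys.map(_ + "-volume"))

  // Container-level step: mount the volumes the pod already declares.
  def mountSecrets(c: Container): Container =
    c.copy(mounts = c.mounts ++ secretsToPaths.map { case (n, p) => s"$n-volume:$p" })
}

object Demo extends App {
  val b    = new SecretsBootstrap(Map("secret1" -> "/var/secret1"))
  val pod  = b.addSecretVolumes(Pod())       // once for the pod
  val main = b.mountSecrets(Container())     // main container
  val init = b.mountSecrets(Container())     // init-container
  assert(pod.volumes == List("secret1-volume"))  // volume declared a single time
}
```

In the actual change, `MountSecretsBootstrap.addSecretVolumes` and `mountSecrets` play these two roles, and the driver and executor builders invoke the volume step only for the pod, not again for the init-container.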
--- .../deploy/k8s/MountSecretsBootstrap.scala | 30 ++++++++++++------- .../k8s/submit/DriverConfigOrchestrator.scala | 16 +++++----- .../steps/BasicDriverConfigurationStep.scala | 2 +- .../submit/steps/DriverMountSecretsStep.scala | 4 +-- .../InitContainerMountSecretsStep.scala | 11 +++---- .../cluster/k8s/ExecutorPodFactory.scala | 6 ++-- .../k8s/{submit => }/SecretVolumeUtils.scala | 18 +++++------ .../BasicDriverConfigurationStepSuite.scala | 4 +-- .../steps/DriverMountSecretsStepSuite.scala | 4 +-- .../InitContainerMountSecretsStepSuite.scala | 7 +---- .../cluster/k8s/ExecutorPodFactorySuite.scala | 14 +++++---- 11 files changed, 61 insertions(+), 55 deletions(-) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit => }/SecretVolumeUtils.scala (71%) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala index 8286546ce064..c35e7db51d40 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala @@ -24,26 +24,36 @@ import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBui private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { /** - * Mounts Kubernetes secrets as secret volumes into the given container in the given pod. + * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. * * @param pod the pod into which the secret volumes are being added. - * @param container the container into which the secret volumes are being mounted. - * @return the updated pod and container with the secrets mounted. + * @return the updated pod with the secret volumes added. */ - def mountSecrets(pod: Pod, container: Container): (Pod, Container) = { + def addSecretVolumes(pod: Pod): Pod = { var podBuilder = new PodBuilder(pod) secretNamesToMountPaths.keys.foreach { name => podBuilder = podBuilder .editOrNewSpec() .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() + .withName(secretVolumeName(name)) + .withNewSecret() + .withSecretName(name) + .endSecret() + .endVolume() .endSpec() } + podBuilder.build() + } + + /** + * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the + * given container. + * + * @param container the container into which the secret volumes are being mounted. + * @return the updated container with the secrets mounted. 
+ */ + def mountSecrets(container: Container): Container = { var containerBuilder = new ContainerBuilder(container) secretNamesToMountPaths.foreach { case (name, path) => containerBuilder = containerBuilder @@ -53,7 +63,7 @@ private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, .endVolumeMount() } - (podBuilder.build(), containerBuilder.build()) + containerBuilder.build() } private def secretVolumeName(secretName: String): String = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index 00c9c4ee4917..c9cc300d6556 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -127,6 +127,12 @@ private[spark] class DriverConfigOrchestrator( Nil } + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil + } + val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { val orchestrator = new InitContainerConfigOrchestrator( sparkJars, @@ -147,19 +153,13 @@ private[spark] class DriverConfigOrchestrator( Nil } - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - Seq( initialSubmissionStep, serviceBootstrapStep, kubernetesCredentialsStep) ++ dependencyResolutionStep ++ - initContainerBootstrapStep ++ - mountSecretsStep + mountSecretsStep ++ + initContainerBootstrapStep } private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index b7a69a7dfd47..eca46b84c606 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -119,7 +119,7 @@ private[spark] class BasicDriverConfigurationStep( .endEnv() .addNewEnv() .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.map(arg => "\"" + arg + "\"").mkString(" ")) + .withValue(appArgs.mkString(" ")) .endEnv() .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala index f872e0f4b65d..91e9a9f21133 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala @@ -28,8 +28,8 @@ private[spark] class DriverMountSecretsStep( bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val 
(pod, container) = bootstrap.mountSecrets( - driverSpec.driverPod, driverSpec.driverContainer) + val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) + val container = bootstrap.mountSecrets(driverSpec.driverContainer) driverSpec.copy( driverPod = pod, driverContainer = container diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala index c0e7bb20cce8..0daa7b95e8aa 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala @@ -28,12 +28,9 @@ private[spark] class InitContainerMountSecretsStep( bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { - val (driverPod, initContainer) = bootstrap.mountSecrets( - spec.driverPod, - spec.initContainer) - spec.copy( - driverPod = driverPod, - initContainer = initContainer - ) + // Mount the secret volumes given that the volumes have already been added to the driver pod + // when mounting the secrets into the main driver container. + val initContainer = bootstrap.mountSecrets(spec.initContainer) + spec.copy(initContainer = initContainer) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index ba5d891f4c77..066d7e9f70ca 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -214,7 +214,7 @@ private[spark] class ExecutorPodFactory( val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = mountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(executorPod, containerWithLimitCores) + (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) }.getOrElse((executorPod, containerWithLimitCores)) val (bootstrappedPod, bootstrappedContainer) = @@ -227,7 +227,9 @@ private[spark] class ExecutorPodFactory( val (pod, mayBeSecretsMountedInitContainer) = initContainerMountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(podWithInitContainer.pod, podWithInitContainer.initContainer) + // Mount the secret volumes given that the volumes have already been added to the + // executor pod when mounting the secrets into the main executor container. 
+ (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer)) }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) val bootstrappedPod = KubernetesUtils.appendInitContainer( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala similarity index 71% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala index 8388c16ded26..16780584a674 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit +package org.apache.spark.deploy.k8s import scala.collection.JavaConverters._ @@ -22,15 +22,15 @@ import io.fabric8.kubernetes.api.model.{Container, Pod} private[spark] object SecretVolumeUtils { - def podHasVolume(driverPod: Pod, volumeName: String): Boolean = { - driverPod.getSpec.getVolumes.asScala.exists(volume => volume.getName == volumeName) + def podHasVolume(pod: Pod, volumeName: String): Boolean = { + pod.getSpec.getVolumes.asScala.exists { volume => + volume.getName == volumeName + } } - def containerHasVolume( - driverContainer: Container, - volumeName: String, - mountPath: String): Boolean = { - driverContainer.getVolumeMounts.asScala.exists(volumeMount => - volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath) + def containerHasVolume(container: Container, volumeName: String, mountPath: String): Boolean = { + container.getVolumeMounts.asScala.exists { volumeMount => + volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath + } } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index e864c6a16eeb..8ee629ac8ddc 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -33,7 +33,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "arg 3") + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") private val CUSTOM_ANNOTATION_KEY = "customAnnotation" private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" @@ -82,7 +82,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") assert(envs(ENV_DRIVER_MEMORY) === "256M") assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "\"arg1\" \"arg2\" \"arg 
3\"") + assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala index 9ec0cb55de5a..960d0bda1d01 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy.k8s.submit.steps import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec class DriverMountSecretsStepSuite extends SparkFunSuite { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala index eab4e1765945..7ac0bde80dfe 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy.k8s.submit.steps.initcontainer import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.SecretVolumeUtils +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} class InitContainerMountSecretsStepSuite extends SparkFunSuite { @@ -44,12 +43,8 @@ class InitContainerMountSecretsStepSuite extends SparkFunSuite { val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( baseInitContainerSpec) - - val podWithSecretsMounted = configuredInitContainerSpec.driverPod val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.podHasVolume(podWithSecretsMounted, volumeName))) Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => assert(SecretVolumeUtils.containerHasVolume( initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7121a802c69c..884da8aabd88 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -25,7 +25,7 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -165,17 +165,19 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef val factory = new ExecutorPodFactory( conf, - None, + Some(secretsBootstrap), Some(initContainerBootstrap), Some(secretsBootstrap)) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getVolumes.size() === 1) + assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume")) + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1")) assert(executor.getSpec.getInitContainers.size() === 1) - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1")) checkOwnerReferences(executor, driverPodUid) } From 0428368c2c5e135f99f62be20877bbbda43be310 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:34:56 -0800 Subject: [PATCH 084/100] [SPARK-22960][K8S] Make build-push-docker-images.sh more dev-friendly. - Make it possible to build images from a git clone. - Make it easy to use minikube to test things. Also fixed what seemed like a bug: the base image wasn't getting the tag provided in the command line. Adding the tag allows users to use multiple Spark builds in the same kubernetes cluster. Tested by deploying images on minikube and running spark-submit from a dev environment; also by building the images with different tags and verifying "docker images" in minikube. Author: Marcelo Vanzin Closes #20154 from vanzin/SPARK-22960. --- docs/running-on-kubernetes.md | 9 +- .../src/main/dockerfiles/driver/Dockerfile | 3 +- .../src/main/dockerfiles/executor/Dockerfile | 3 +- .../dockerfiles/init-container/Dockerfile | 3 +- .../main/dockerfiles/spark-base/Dockerfile | 7 +- sbin/build-push-docker-images.sh | 120 +++++++++++++++--- 6 files changed, 117 insertions(+), 28 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e491329136a3..2d69f636472a 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -16,6 +16,9 @@ Kubernetes scheduler that has been added to Spark. you may setup a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). * We recommend using the latest release of minikube with the DNS addon enabled. + * Be aware that the default minikube configuration is not enough for running Spark applications. + We recommend 3 CPUs and 4g of memory to be able to start a simple Spark application with a single + executor. 
* You must have appropriate permissions to list, create, edit and delete [pods](https://kubernetes.io/docs/user-guide/pods/) in your cluster. You can verify that you can list these resources by running `kubectl auth can-i pods`. @@ -197,7 +200,7 @@ kubectl port-forward 4040:4040 Then, the Spark driver UI can be accessed on `http://localhost:4040`. -### Debugging +### Debugging There may be several kinds of failures. If the Kubernetes API server rejects the request made from spark-submit, or the connection is refused for a different reason, the submission logic should indicate the error encountered. However, if there @@ -215,8 +218,8 @@ If the pod has encountered a runtime error, the status can be probed further usi kubectl logs ``` -Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark -application, includling all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of +Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark +application, including all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of the Spark application. ## Kubernetes Features diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index 45fbcd9cd0de..ff5289e10c21 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 0f806cf7e148..3eabb42d4d85 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 047056ab2633..e0a249e0ac71 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. 
E.g.: diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile index 222e777db3a8..da1d6b9e161c 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile @@ -17,6 +17,9 @@ FROM openjdk:8-alpine +ARG spark_jars +ARG img_path + # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark @@ -34,11 +37,11 @@ RUN set -ex && \ ln -sv /bin/bash /bin/sh && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY jars /opt/spark/jars +COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY kubernetes/dockerfiles/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark-base/entrypoint.sh /opt/ ENV SPARK_HOME /opt/spark diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index b3137598692d..bb8806dd33f3 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -19,29 +19,94 @@ # This script builds and pushes docker images when run from a release of Spark # with Kubernetes support. -declare -A path=( [spark-driver]=kubernetes/dockerfiles/driver/Dockerfile \ - [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile \ - [spark-init]=kubernetes/dockerfiles/init-container/Dockerfile ) +function error { + echo "$@" 1>&2 + exit 1 +} + +# Detect whether this is a git clone or a Spark distribution and adjust paths +# accordingly. +if [ -z "${SPARK_HOME}" ]; then + SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi +. "${SPARK_HOME}/bin/load-spark-env.sh" + +if [ -f "$SPARK_HOME/RELEASE" ]; then + IMG_PATH="kubernetes/dockerfiles" + SPARK_JARS="jars" +else + IMG_PATH="resource-managers/kubernetes/docker/src/main/dockerfiles" + SPARK_JARS="assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker images. This script must be run from a runnable distribution of Apache Spark." +fi + +declare -A path=( [spark-driver]="$IMG_PATH/driver/Dockerfile" \ + [spark-executor]="$IMG_PATH/executor/Dockerfile" \ + [spark-init]="$IMG_PATH/init-container/Dockerfile" ) + +function image_ref { + local image="$1" + local add_repo="${2:-1}" + if [ $add_repo = 1 ] && [ -n "$REPO" ]; then + image="$REPO/$image" + fi + if [ -n "$TAG" ]; then + image="$image:$TAG" + fi + echo "$image" +} function build { - docker build -t spark-base -f kubernetes/dockerfiles/spark-base/Dockerfile . + local base_image="$(image_ref spark-base 0)" + docker build --build-arg "spark_jars=$SPARK_JARS" \ + --build-arg "img_path=$IMG_PATH" \ + -t "$base_image" \ + -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build -t ${REPO}/$image:${TAG} -f ${path[$image]} . + docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . done } - function push { for image in "${!path[@]}"; do - docker push ${REPO}/$image:${TAG} + docker push "$(image_ref $image)" done } function usage { - echo "This script must be run from a runnable distribution of Apache Spark." 
- echo "Usage: ./sbin/build-push-docker-images.sh -r -t build" - echo " ./sbin/build-push-docker-images.sh -r -t push" - echo "for example: ./sbin/build-push-docker-images.sh -r docker.io/myrepo -t v2.3.0 push" + cat </dev/null; then + error "Cannot find minikube." + fi + eval $(minikube docker-env) + ;; esac done -if [ -z "$REPO" ] || [ -z "$TAG" ]; then +case "${@: -1}" in + build) + build + ;; + push) + if [ -z "$REPO" ]; then + usage + exit 1 + fi + push + ;; + *) usage -else - case "${@: -1}" in - build) build;; - push) push;; - *) usage;; - esac -fi + exit 1 + ;; +esac From df7fc3ef3899cadd252d2837092bebe3442d6523 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 5 Jan 2018 10:16:34 +0800 Subject: [PATCH 085/100] [SPARK-22957] ApproxQuantile breaks if the number of rows exceeds MaxInt ## What changes were proposed in this pull request? 32bit Int was used for row rank. That overflowed in a dataframe with more than 2B rows. ## How was this patch tested? Added test, but ignored, as it takes 4 minutes. Author: Juliusz Sompolski Closes #20152 from juliuszsompolski/SPARK-22957. --- .../aggregate/ApproximatePercentile.scala | 12 ++++++------ .../spark/sql/catalyst/util/QuantileSummaries.scala | 8 ++++---- .../org/apache/spark/sql/DataFrameStatSuite.scala | 8 ++++++++ 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 149ac265e6ed..a45854a3b514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -296,8 +296,8 @@ object ApproximatePercentile { Ints.BYTES + Doubles.BYTES + Longs.BYTES + // length of summary.sampled Ints.BYTES + - // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)] - summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES) + // summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)] + summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES) } final def serialize(obj: PercentileDigest): Array[Byte] = { @@ -312,8 +312,8 @@ object ApproximatePercentile { while (i < summary.sampled.length) { val stat = summary.sampled(i) buffer.putDouble(stat.value) - buffer.putInt(stat.g) - buffer.putInt(stat.delta) + buffer.putLong(stat.g) + buffer.putLong(stat.delta) i += 1 } buffer.array() @@ -330,8 +330,8 @@ object ApproximatePercentile { var i = 0 while (i < sampledLength) { val value = buffer.getDouble() - val g = buffer.getInt() - val delta = buffer.getInt() + val g = buffer.getLong() + val delta = buffer.getLong() sampled(i) = Stats(value, g, delta) i += 1 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index eb7941cf9e6a..b013add9c977 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -105,7 +105,7 @@ class QuantileSummaries( if (newSamples.isEmpty || (sampleIdx == sampled.length && opsIdx == sorted.length - 1)) { 0 } else { - math.floor(2 * relativeError * currentCount).toInt + math.floor(2 * relativeError * currentCount).toLong } 
val tuple = Stats(currentSample, 1, delta) @@ -192,10 +192,10 @@ class QuantileSummaries( } // Target rank - val rank = math.ceil(quantile * count).toInt + val rank = math.ceil(quantile * count).toLong val targetError = relativeError * count // Minimum rank at current sample - var minRank = 0 + var minRank = 0L var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) @@ -235,7 +235,7 @@ object QuantileSummaries { * @param g the minimum rank jump from the previous value's minimum rank * @param delta the maximum span of the rank. */ - case class Stats(value: Double, g: Int, delta: Int) + case class Stats(value: Double, g: Long, delta: Long) private def compressImmut( currentSamples: IndexedSeq[Stats], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b21c3b64a2..5169d2b5fc6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -260,6 +260,14 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(res2(1).isEmpty) } + // SPARK-22957: check for 32bit overflow when computing rank. + // ignored - takes 4 minutes to run. + ignore("approx quantile 4: test for Int overflow") { + val res = spark.range(3000000000L).stat.approxQuantile("id", Array(0.8, 0.9), 0.05) + assert(res(0) > 2200000000.0) + assert(res(1) > 2200000000.0) + } + test("crosstab") { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { val rng = new Random() From 52fc5c17d9d784b846149771b398e741621c0b5c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 5 Jan 2018 14:02:21 +0800 Subject: [PATCH 086/100] [SPARK-22825][SQL] Fix incorrect results of Casting Array to String ## What changes were proposed in this pull request? This pr fixed the issue when casting arrays into strings; ``` scala> val df = spark.range(10).select('id.cast("integer")).agg(collect_list('id).as('ids)) scala> df.write.saveAsTable("t") scala> sql("SELECT cast(ids as String) FROM t").show(false) +------------------------------------------------------------------+ |ids | +------------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeArrayData8bc285df| +------------------------------------------------------------------+ ``` This pr modified the result into; ``` +------------------------------+ |ids | +------------------------------+ |[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]| +------------------------------+ ``` ## How was this patch tested? Added tests in `CastSuite` and `SQLQuerySuite`. Author: Takeshi Yamamuro Closes #20024 from maropu/SPARK-22825. 
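A minimal way to check the fix interactively (a sketch, not part of the patch; assumes a running `SparkSession` named `spark`, as in spark-shell):

```scala
// Minimal check of the new array-to-string cast behaviour.
import org.apache.spark.sql.functions.collect_list

val df = spark.range(5).agg(collect_list("id").as("ids"))
// Before this change the cast rendered the internal UnsafeArrayData reference;
// with it, the element values are printed, e.g. [0, 1, 2, 3, 4].
df.selectExpr("CAST(ids AS STRING)").show(false)
```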
--- .../codegen/UTF8StringBuilder.java | 78 +++++++++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 68 ++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 25 ++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 2 - 4 files changed, 171 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java new file mode 100644 index 000000000000..f0f66bae245f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated + * {@link UTF8String} at the end. + */ +public class UTF8StringBuilder { + + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; + + public UTF8StringBuilder() { + // Since initial buffer size is 16 in `StringBuilder`, we set the same size here + this.buffer = new byte[16]; + } + + // Grows the buffer by at least `neededSize` + private void grow(int neededSize) { + if (neededSize > ARRAY_MAX - totalSize()) { + throw new UnsupportedOperationException( + "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + + "exceeds size limitation " + ARRAY_MAX); + } + final int length = totalSize() + neededSize; + if (buffer.length < length) { + int newLength = length < ARRAY_MAX / 2 ? 
length * 2 : ARRAY_MAX; + final byte[] tmp = new byte[newLength]; + Platform.copyMemory( + buffer, + Platform.BYTE_ARRAY_OFFSET, + tmp, + Platform.BYTE_ARRAY_OFFSET, + totalSize()); + buffer = tmp; + } + } + + private int totalSize() { + return cursor - Platform.BYTE_ARRAY_OFFSET; + } + + public void append(UTF8String value) { + grow(value.numBytes()); + value.writeToMemory(buffer, cursor); + cursor += value.numBytes(); + } + + public void append(String value) { + append(UTF8String.fromString(value)); + } + + public UTF8String build() { + return UTF8String.fromBytes(buffer, 0, totalSize()); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 274d8813f16d..d4fc5e0f168a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -206,6 +206,28 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone))) + case ArrayType(et, _) => + buildCast[ArrayData](_, array => { + val builder = new UTF8StringBuilder + builder.append("[") + if (array.numElements > 0) { + val toUTF8String = castToString(et) + if (!array.isNullAt(0)) { + builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < array.numElements) { + builder.append(",") + if (!array.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -597,6 +619,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ } + private def writeArrayToStringBuilder( + et: DataType, + array: String, + buffer: String, + ctx: CodegenContext): String = { + val elementToStringCode = castToStringCode(et, ctx) + val funcName = ctx.freshName("elementToString") + val elementToStringFunc = ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(et)} element) { + | UTF8String elementStr = null; + | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + | return elementStr; + |} + """.stripMargin) + + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($array.numElements() > 0) { + | if (!$array.isNullAt(0)) { + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { + | $buffer.append(","); + | if (!$array.isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -608,6 +665,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val tz = ctx.addReferenceObj("timeZone", timeZone) (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" + case ArrayType(et, _) 
=> + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeArrayElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1dd040e4696a..e3ed7171defd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -853,4 +853,29 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { cast("2", LongType).genCode(ctx) assert(ctx.inlinedMutableStates.length == 0) } + + test("SPARK-22825 Cast array to string") { + val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType) + checkEvaluation(ret1, "[1, 2, 3, 4, 5]") + val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType) + checkEvaluation(ret2, "[ab, cde, f]") + val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType) + checkEvaluation(ret3, "[ab,, c]") + val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType) + checkEvaluation(ret4, "[ab, cde, f]") + val ret5 = cast( + Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)), + StringType) + checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]") + val ret6 = cast( + Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)), + StringType) + checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]") + val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType) + checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]") + val ret8 = cast( + Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))), + StringType) + checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e077285ade5..96bf65fce9c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf From cf0aa65576acbe0209c67f04c029058fd73555c1 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 4 Jan 2018 22:45:15 -0800 Subject: [PATCH 087/100] [SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed memory tradeoff for TrainValidationSplit ## What changes were proposed in this pull request? Avoid holding all models in memory for `TrainValidationSplit`. ## How was this patch tested? 
Existing tests. Author: Bago Amirbekian Closes #20143 from MrBago/trainValidMemoryFix. --- .../spark/ml/tuning/CrossValidator.scala | 4 +++- .../spark/ml/tuning/TrainValidationSplit.scala | 18 ++++-------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 095b54c0fe83..a0b507d2e718 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } (executionContext) } - // Wait for metrics to be calculated before unpersisting validation dataset + // Wait for metrics to be calculated val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + + // Unpersist training & validation set once all metrics have been produced trainingDataset.unpersist() validationDataset.unpersist() foldMetrics diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c73bd1847547..8826ef3271bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { + val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => + Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] if (collectSubModelsParam) { subModels.get(paramIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") @@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Wait for all metrics to be calculated val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) - // Unpersist validation set once all metrics have been produced + // Unpersist training & validation set once all metrics have been produced + trainingDataset.unpersist() validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}") From 6cff7d19f6a905fe425bd6892fe7ca014c0e696b Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 23:23:41 -0800 Subject: [PATCH 088/100] [SPARK-22757][K8S] Enable spark.jars and spark.files in KUBERNETES mode ## What changes were proposed in this pull request? We missed enabling `spark.files` and `spark.jars` in https://github.com/apache/spark/pull/19954. 
The result is that remote dependencies specified through `spark.files` or `spark.jars` are not included in the list of remote dependencies to be downloaded by the init-container. This PR fixes it. ## How was this patch tested? Manual tests. vanzin This replaces https://github.com/apache/spark/pull/20157. foxish Author: Yinan Li Closes #20160 from liyinan926/SPARK-22757. --- .../src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index cbe1f2c3e08a..1e381965c52b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -584,10 +584,11 @@ object SparkSubmit extends CommandLineUtils with Logging { confKey = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.cores.max"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.files"), OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, confKey = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, From 51c33bd0d402af9e0284c6cbc0111f926446bfba Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Fri, 5 Jan 2018 21:32:39 +0800 Subject: [PATCH 089/100] [SPARK-22961][REGRESSION] Constant columns should generate QueryPlanConstraints ## What changes were proposed in this pull request? #19201 introduced the following regression: given something like `df.withColumn("c", lit(2))`, we're no longer picking up `c === 2` as a constraint and infer filters from it when joins are involved, which may lead to noticeable performance degradation. This patch re-enables this optimization by picking up Aliases of Literals in Projection lists as constraints and making sure they're not treated as aliased columns. ## How was this patch tested? Unit test was added. Author: Adrian Ionescu Closes #20155 from adrian-ionescu/constant_constraints. 
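For reference, the regressed query shape looks roughly like the following sketch (not from the patch; assumes a `SparkSession` named `spark`). With aliased literals contributing constraints again, the optimizer can infer a filter on the other side of the join:

```scala
// Sketch of a join against a constant column built with lit().
import org.apache.spark.sql.functions.lit

val left = spark.range(1000).withColumn("c", lit(2L))
val right = spark.range(1000).toDF("k")
// Picking up `c = 2` as a constraint lets InferFiltersFromConstraints derive
// `k = 2` on the right side and push it below the join.
left.join(right, left("c") === right("k")).explain(true)
```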
--- .../sql/catalyst/plans/logical/LogicalPlan.scala | 2 ++ .../plans/logical/QueryPlanConstraints.scala | 2 +- .../InferFiltersFromConstraintsSuite.scala | 13 +++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a38458add7b5..ff2a0ec58856 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -247,6 +247,8 @@ abstract class UnaryNode extends LogicalPlan { protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { + case a @ Alias(l: Literal, _) => + allConstraints += EqualTo(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index b0f611fd38de..9c0a30a47f83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -98,7 +98,7 @@ trait QueryPlanConstraints { self: LogicalPlan => // we may avoid producing recursive constraints. private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( expressions.collect { - case a: Alias => (a.toAttribute, a.child) + case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child) } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 5580f8604ec7..a0708bf7eee9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -236,4 +236,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, originalQuery) } } + + test("constraints should be inferred from aliased literals") { + val originalLeft = testRelation.subquery('left).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") + + val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") + val condition = Some("left.a".attr === "right.two".attr) + + val original = originalLeft.join(right, Inner, condition) + val correct = optimizedLeft.join(right, Inner, condition) + + comparePlans(Optimize.execute(original.analyze), correct.analyze) + } } From c0b7424ecacb56d3e7a18acc11ba3d5e7be57c43 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 5 Jan 2018 09:58:28 -0800 Subject: [PATCH 090/100] [SPARK-22940][SQL] HiveExternalCatalogVersionsSuite should succeed on platforms that don't have wget ## What changes were proposed in this pull request? Modified HiveExternalCatalogVersionsSuite.scala to use Utils.doFetchFile to download different versions of Spark binaries rather than launching wget as an external process. On platforms that don't have wget installed, this suite fails with an error. cloud-fan : would you like to check this change? ## How was this patch tested? 1) test-only of HiveExternalCatalogVersionsSuite on several platforms. Tested bad mirror, read timeout, and redirects. 2) ./dev/run-tests Author: Bruce Robbins Closes #20147 from bersprockets/SPARK-22940-alt. --- .../HiveExternalCatalogVersionsSuite.scala | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index a3d5b941a676..ae4aeb7b4ce4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.sql.hive import java.io.File -import java.nio.file.Files +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} import scala.sys.process._ -import org.apache.spark.TestUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SecurityManager, SparkConf, TestUtils} import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -55,14 +58,19 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def tryDownloadSpark(version: String, path: String): Unit = { // Try mirrors a few times until one succeeds for (i <- 0 until 3) { + // we don't retry on a failure to get mirror url. 
If we can't get a mirror url, + // the test fails (getStringFromUrl will throw an exception) val preferredMirror = - Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + val filename = s"spark-$version-bin-hadoop2.7.tgz" + val url = s"$preferredMirror/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") - if (Seq("wget", url, "-q", "-P", path).! == 0) { + try { + getFileFromUrl(url, path, filename) return + } catch { + case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } - logWarning(s"Failed to download Spark $version from $url") } fail(s"Unable to download Spark $version") } @@ -85,6 +93,34 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new File(tmpDataDir, name).getCanonicalPath } + private def getFileFromUrl(urlString: String, targetDir: String, filename: String): Unit = { + val conf = new SparkConf + // if the caller passes the name of an existing file, we want doFetchFile to write over it with + // the contents from the specified url. + conf.set("spark.files.overwrite", "true") + val securityManager = new SecurityManager(conf) + val hadoopConf = new Configuration + + val outDir = new File(targetDir) + if (!outDir.exists()) { + outDir.mkdirs() + } + + // propagate exceptions up to the caller of getFileFromUrl + Utils.doFetchFile(urlString, outDir, filename, conf, securityManager, hadoopConf) + } + + private def getStringFromUrl(urlString: String): String = { + val contentFile = File.createTempFile("string-", ".txt") + contentFile.deleteOnExit() + + // exceptions will propagate to the caller of getStringFromUrl + getFileFromUrl(urlString, contentFile.getParent, contentFile.getName) + + val contentPath = Paths.get(contentFile.toURI) + new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) + } + override def beforeAll(): Unit = { super.beforeAll() From 930b90a84871e2504b57ed50efa7b8bb52d3ba44 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 5 Jan 2018 11:51:25 -0800 Subject: [PATCH 091/100] [SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator ## What changes were proposed in this pull request? Follow-up cleanups for the OneHotEncoderEstimator PR. See some discussion in the original PR: https://github.com/apache/spark/pull/19527 or read below for what this PR includes: * configedCategorySize: I reverted this to return an Array. I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF. * encoder: I reorganized the logic to show what I meant in the comment in the previous PR. I think it's simpler but am open to suggestions. I also made some small style cleanups based on IntelliJ warnings. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley Closes #20132 from jkbradley/viirya-SPARK-13030. 
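For context, a small usage sketch of the estimator whose Params are cleaned up here (illustrative only; assumes a `SparkSession` named `spark` and hypothetical column names):

```scala
// Usage sketch for OneHotEncoderEstimator with the documented handleInvalid semantics.
import org.apache.spark.ml.feature.OneHotEncoderEstimator

val df = spark.createDataFrame(Seq(
  (0.0, 1.0), (1.0, 0.0), (2.0, 1.0), (0.0, 2.0)
)).toDF("catA", "catB")

val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("catA", "catB"))
  .setOutputCols(Array("catA_vec", "catB_vec"))
  .setHandleInvalid("keep")  // consulted only at transform(); invalid data during fit() still errors

val model = encoder.fit(df)
model.transform(df).show(false)
```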
--- .../ml/feature/OneHotEncoderEstimator.scala | 92 ++++++++++--------- 1 file changed, 49 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 074622d41e28..bd1e3426c878 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -30,24 +30,27 @@ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions.{col, lit, udf} -import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid with HasInputCols with HasOutputCols { /** - * Param for how to handle invalid data. + * Param for how to handle invalid data during transform(). * Options are 'keep' (invalid data presented as an extra categorical feature) or * 'error' (throw an error). + * Note that this Param is only used during transform; during fitting, invalid data + * will result in an error. * Default: "error" * @group param */ @Since("2.3.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - "How to handle invalid data " + + "How to handle invalid data during transform(). " + "Options are 'keep' (invalid data presented as an extra categorical feature) " + - "or error (throw an error).", + "or error (throw an error). Note that this Param is only used during transform; " + + "during fitting, invalid data will result in an error.", ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) @@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid def getDropLast: Boolean = $(dropLast) protected def validateAndTransformSchema( - schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = { + schema: StructType, + dropLast: Boolean, + keepInvalid: Boolean): StructType = { val inputColNames = $(inputCols) val outputColNames = $(outputCols) - val existingFields = schema.fields require(inputColNames.length == outputColNames.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat override def load(path: String): OneHotEncoderEstimator = super.load(path) } +/** + * @param categorySizes Original number of categories for each feature being encoded. + * The array contains one value for each input column, in order. + */ @Since("2.3.0") class OneHotEncoderModel private[ml] ( @Since("2.3.0") override val uid: String, @@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] ( import OneHotEncoderModel._ - // Returns the category size for a given index with `dropLast` and `handleInvalid` + // Returns the category size for each index with `dropLast` and `handleInvalid` // taken into account. 
- private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = { + private def getConfigedCategorySizes: Array[Int] = { val dropLast = getDropLast val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID if (!dropLast && keepInvalid) { // When `handleInvalid` is "keep", an extra category is added as last category // for invalid data. - orgCategorySize + 1 + categorySizes.map(_ + 1) } else if (dropLast && !keepInvalid) { // When `dropLast` is true, the last category is removed. - orgCategorySize - 1 + categorySizes.map(_ - 1) } else { // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid // data is removed. Thus, it is the same as the plain number of categories. - orgCategorySize + categorySizes } } private def encoder: UserDefinedFunction = { - val oneValue = Array(1.0) - val emptyValues = Array.empty[Double] - val emptyIndices = Array.empty[Int] - val dropLast = getDropLast - val handleInvalid = getHandleInvalid - val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val configedSizes = getConfigedCategorySizes + val localCategorySizes = categorySizes // The udf performed on input data. The first parameter is the input value. The second - // parameter is the index of input. - udf { (label: Double, idx: Int) => - val plainNumCategories = categorySizes(idx) - val size = configedCategorySize(plainNumCategories, idx) - - if (label < 0) { - throw new SparkException(s"Negative value: $label. Input can't be negative.") - } else if (label == size && dropLast && !keepInvalid) { - // When `dropLast` is true and `handleInvalid` is not "keep", - // the last category is removed. - Vectors.sparse(size, emptyIndices, emptyValues) - } else if (label >= plainNumCategories && keepInvalid) { - // When `handleInvalid` is "keep", encodes invalid data to last category (and removed - // if `dropLast` is true) - if (dropLast) { - Vectors.sparse(size, emptyIndices, emptyValues) + // parameter is the index in inputCols of the column being encoded. + udf { (label: Double, colIdx: Int) => + val origCategorySize = localCategorySizes(colIdx) + // idx: index in vector of the single 1-valued element + val idx = if (label >= 0 && label < origCategorySize) { + label + } else { + if (keepInvalid) { + origCategorySize } else { - Vectors.sparse(size, Array(size - 1), oneValue) + if (label < 0) { + throw new SparkException(s"Negative value: $label. Input can't be negative. " + + s"To handle invalid values, set Param handleInvalid to " + + s"${OneHotEncoderEstimator.KEEP_INVALID}") + } else { + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + } } - } else if (label < plainNumCategories) { - Vectors.sparse(size, Array(label.toInt), oneValue) + } + + val size = configedSizes(colIdx) + if (idx < size) { + Vectors.sparse(size, Array(idx.toInt), Array(1.0)) } else { - assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) - throw new SparkException(s"Unseen value: $label. 
To handle unseen values, " + - s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + Vectors.sparse(size, Array.empty[Int], Array.empty[Double]) } } } @@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] ( @Since("2.3.0") override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) - val outputColNames = $(outputCols) require(inputColNames.length == categorySizes.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] ( * account. Mismatched numbers will cause exception. */ private def verifyNumOfValues(schema: StructType): StructType = { + val configedSizes = getConfigedCategorySizes $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => val inputColName = $(inputCols)(idx) val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) @@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] ( // comparing with expected category number with `handleInvalid` and // `dropLast` taken into account. if (attrGroup.attributes.nonEmpty) { - val numCategories = configedCategorySize(categorySizes(idx), idx) + val numCategories = configedSizes(idx) require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + - s"$numCategories categorical values for input column ${inputColName}, " + + s"$numCategories categorical values for input column $inputColName, " + s"but the input column had metadata specifying ${attrGroup.size} values.") } } @@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] ( val transformedSchema = transformSchema(dataset.schema, logging = true) val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID - val encodedColumns = (0 until $(inputCols).length).map { idx => + val encodedColumns = $(inputCols).indices.map { idx => val inputColName = $(inputCols)(idx) val outputColName = $(outputCols)(idx) From ea956833017fcbd8ed2288368bfa2e417a2251c5 Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Fri, 5 Jan 2018 17:25:28 -0800 Subject: [PATCH 092/100] [SPARK-22914][DEPLOY] Register history.ui.port ## What changes were proposed in this pull request? Register spark.history.ui.port as a known spark conf to be used in substitution expressions even if it's not set explicitly. ## How was this patch tested? Added unit test to demonstrate the issue Author: Gera Shegalov Author: Gera Shegalov Closes #20098 from gerashegalov/gera/register-SHS-port-conf. 
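For illustration, the substitution this makes possible (a sketch mirroring the added unit test; the host name is hypothetical):

```scala
// With spark.history.ui.port registered (default 18080), the second placeholder
// below resolves even when the port is never set explicitly.
import org.apache.spark.SparkConf

val conf = new SparkConf().set("spark.yarn.historyServer.address",
  "http://${hadoopconf-yarn.resourcemanager.hostname}:${spark.history.ui.port}")
// e.g. expands to http://rm.host.com:18080/history/<appId>/<attemptId> for the history URL.
```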
--- .../spark/deploy/history/HistoryServer.scala | 3 +- .../apache/spark/deploy/history/config.scala | 5 +++ .../spark/deploy/yarn/ApplicationMaster.scala | 17 +++++--- .../deploy/yarn/ApplicationMasterSuite.scala | 43 +++++++++++++++++++ 4 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 75484f5c9f30..0ec4afad0308 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -28,6 +28,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.history.config.HISTORY_SERVER_UI_PORT import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot} @@ -276,7 +277,7 @@ object HistoryServer extends Logging { .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] - val port = conf.getInt("spark.history.ui.port", 18080) + val port = conf.get(HISTORY_SERVER_UI_PORT) val server = new HistoryServer(conf, provider, securityManager, port) server.bind() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index 22b6d49d8e2a..efdbf672bb52 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -44,4 +44,9 @@ private[spark] object config { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10g") + val HISTORY_SERVER_UI_PORT = ConfigBuilder("spark.history.ui.port") + .doc("Web UI port to bind Spark History Server") + .intConf + .createWithDefault(18080) + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index b2576b0d7263..4d5e3bb04367 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -427,11 +427,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends uiAddress: Option[String]) = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() - val historyAddress = - _sparkConf.get(HISTORY_SERVER_ADDRESS) - .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } - .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } - .getOrElse("") + val historyAddress = ApplicationMaster + .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) val driverUrl = RpcEndpointAddress( _sparkConf.get("spark.driver.host"), @@ -834,6 +831,16 @@ object ApplicationMaster extends Logging { master.getAttemptId } + private[spark] def getHistoryServerAddress( + sparkConf: SparkConf, + yarnConf: YarnConfiguration, + appId: String, + attemptId: String): String = { + sparkConf.get(HISTORY_SERVER_ADDRESS) + .map { text => 
SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } + .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } + .getOrElse("") + } } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala new file mode 100644 index 000000000000..695a82f3583e --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.hadoop.yarn.conf.YarnConfiguration + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class ApplicationMasterSuite extends SparkFunSuite { + + test("history url with hadoop and spark substitutions") { + val host = "rm.host.com" + val port = 18080 + val sparkConf = new SparkConf() + + sparkConf.set("spark.yarn.historyServer.address", + "http://${hadoopconf-yarn.resourcemanager.hostname}:${spark.history.ui.port}") + val yarnConf = new YarnConfiguration() + yarnConf.set("yarn.resourcemanager.hostname", host) + val appId = "application_123_1" + val attemptId = appId + "_1" + + val shsAddr = ApplicationMaster + .getHistoryServerAddress(sparkConf, yarnConf, appId, attemptId) + + assert(shsAddr === s"http://${host}:${port}/history/${appId}/${attemptId}") + } +} From e8af7e8aeca15a6107248f358d9514521ffdc6d3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 6 Jan 2018 09:26:03 +0800 Subject: [PATCH 093/100] [SPARK-22937][SQL] SQL elt output binary for binary inputs ## What changes were proposed in this pull request? This pr modified `elt` to output binary for binary inputs. `elt` in the current master always output data as a string. But, in some databases (e.g., MySQL), if all inputs are binary, `elt` also outputs binary (Also, this might be a small surprise). This pr is related to #19977. ## How was this patch tested? Added tests in `SQLQueryTestSuite` and `TypeCoercionSuite`. Author: Takeshi Yamamuro Closes #20135 from maropu/SPARK-22937. 
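A quick way to observe the change (a sketch, not part of the patch; assumes a `SparkSession` named `spark`):

```scala
// elt over all-binary inputs now reports a binary result type.
val q = "SELECT elt(1, encode('ab', 'utf-8'), encode('cd', 'utf-8')) AS col"
spark.sql(q).printSchema()   // col: binary

// Restore the pre-2.3 string-only behaviour:
spark.conf.set("spark.sql.function.eltOutputAsString", "true")
spark.sql(q).printSchema()   // col: string
```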
--- docs/sql-programming-guide.md | 2 + .../sql/catalyst/analysis/TypeCoercion.scala | 29 +++++ .../expressions/stringExpressions.scala | 46 ++++--- .../apache/spark/sql/internal/SQLConf.scala | 8 ++ .../catalyst/analysis/TypeCoercionSuite.scala | 54 ++++++++ .../inputs/typeCoercion/native/elt.sql | 44 +++++++ .../results/typeCoercion/native/elt.sql.out | 115 ++++++++++++++++++ 7 files changed, 281 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index dc3e384008d2..b50f9360b866 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1783,6 +1783,8 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e9436367c7e2..e8669c4637d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: StackCoercion :: @@ -684,6 +685,34 @@ object TypeCoercion { } } + /** + * Coerces the types of [[Elt]] children to expected ones. + * + * If `spark.sql.function.eltOutputAsString` is false and all children types are binary, + * the expected types are binary. Otherwise, the expected ones are strings. 
+ */ + case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { + + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || + !children.tail.map(_.dataType).forall(_ == BinaryType)) { + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail + } + c.copy(children = newIndex +: newInputs) + } + } + } + /** * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 41dc762154a4..e004bfc6af47 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -271,33 +271,45 @@ case class ConcatWs(children: Seq[Expression]) } } +/** + * An expression that returns the `n`-th input in given inputs. + * If all inputs are binary, `elt` returns an output as binary. Otherwise, it returns as string. + * If any input is null, `elt` returns null. + */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., returns `str2` when `n` is 2.", + usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.", examples = """ Examples: > SELECT _FUNC_(1, 'scala', 'java'); scala """) // scalastyle:on line.size.limit -case class Elt(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes { +case class Elt(children: Seq[Expression]) extends Expression { private lazy val indexExpr = children.head - private lazy val stringExprs = children.tail.toArray + private lazy val inputExprs = children.tail.toArray /** This expression is always nullable because it returns null if index is out of range. 
*/ override def nullable: Boolean = true - override def dataType: DataType = StringType - - override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType) + override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType) override def checkInputDataTypes(): TypeCheckResult = { if (children.size < 2) { TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments") } else { - super[ImplicitCastInputTypes].checkInputDataTypes() + val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) + if (indexType != IntegerType) { + return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + + s"have IntegerType, but it's $indexType") + } + if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have StringType or BinaryType, but it's " + + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") } } @@ -307,27 +319,27 @@ case class Elt(children: Seq[Expression]) null } else { val index = indexObj.asInstanceOf[Int] - if (index <= 0 || index > stringExprs.length) { + if (index <= 0 || index > inputExprs.length) { null } else { - stringExprs(index - 1).eval(input) + inputExprs(index - 1).eval(input) } } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = indexExpr.genCode(ctx) - val strings = stringExprs.map(_.genCode(ctx)) + val inputs = inputExprs.map(_.genCode(ctx)) val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal") + val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") - val assignStringValue = strings.zipWithIndex.map { case (eval, index) => + val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" |if ($indexVal == ${index + 1}) { | ${eval.code} - | $stringVal = ${eval.isNull} ? null : ${eval.value}; + | $inputVal = ${eval.isNull} ? null : ${eval.value}; | $indexMatched = true; | continue; |} @@ -335,7 +347,7 @@ case class Elt(children: Seq[Expression]) } val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = assignStringValue, + expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, returnType = ctx.JAVA_BOOLEAN, @@ -361,11 +373,11 @@ case class Elt(children: Seq[Expression]) |${index.code} |final int $indexVal = ${index.value}; |${ctx.JAVA_BOOLEAN} $indexMatched = false; - |$stringVal = null; + |$inputVal = null; |do { | $codes |} while (false); - |final UTF8String ${ev.value} = $stringVal; + |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5d6edf6b8abe..80b8965e084a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1052,6 +1052,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString") + .doc("When this option is set to false and all inputs are binary, `elt` returns " + + "an output as binary. 
Otherwise, it returns as a string. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1412,6 +1418,8 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 3661530cd622..52a7ebdafd7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -923,6 +923,60 @@ class TypeCoercionSuite extends AnalysisTest { } } + test("type coercion for Elt") { + val rule = TypeCoercion.EltCoercion(conf) + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), + Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), + Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(null), Literal("abc"))), + Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(1), Literal("234"))), + Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), + Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), + Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), + Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(Decimal(10)))), + Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf("spark.sql.function.eltOutputAsString" -> "true") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf("spark.sql.function.eltOutputAsString" -> "false") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) + } + } + test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use 
something more than a literal to avoid triggering the simplification rules. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql new file mode 100644 index 000000000000..717616f91db0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql @@ -0,0 +1,44 @@ +-- Mixed inputs (output type is string) +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +); + +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- turn on eltOutputAsString +set spark.sql.function.eltOutputAsString=true; + +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + +-- turn off eltOutputAsString +set spark.sql.function.eltOutputAsString=false; + +-- Elt binary inputs (output type is binary) +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out new file mode 100644 index 000000000000..b62e1b682604 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out @@ -0,0 +1,115 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +) +-- !query 0 schema +struct +-- !query 0 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 1 +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 1 schema +struct +-- !query 1 output +10 +11 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 2 +set spark.sql.function.eltOutputAsString=true +-- !query 2 schema +struct +-- !query 2 output +spark.sql.function.eltOutputAsString true + + +-- !query 3 +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 3 schema +struct +-- !query 3 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 4 +set spark.sql.function.eltOutputAsString=false +-- !query 4 schema +struct +-- !query 4 output +spark.sql.function.eltOutputAsString false + + +-- !query 5 +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 5 schema +struct +-- !query 5 output +1 +10 +2 +3 +4 +5 +6 +7 +8 +9 From bf65cd3cda46d5480bfcd13110975c46ca631972 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Fri, 5 Jan 2018 17:29:27 -0800 Subject: [PATCH 094/100] [SPARK-22960][K8S] Revert use of ARG base_image in images ## What changes were proposed in this pull request? 
This PR reverts the `ARG base_image` before `FROM` in the images of driver, executor, and init-container, introduced in https://github.com/apache/spark/pull/20154. The reason is Docker versions before 17.06 do not support this use (`ARG` before `FROM`). ## How was this patch tested? Tested manually. vanzin foxish kimoonkim Author: Yinan Li Closes #20170 from liyinan926/master. --- .../docker/src/main/dockerfiles/driver/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/executor/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/init-container/Dockerfile | 3 +-- sbin/build-push-docker-images.sh | 8 ++++---- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index ff5289e10c21..45fbcd9cd0de 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 3eabb42d4d85..0f806cf7e148 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index e0a249e0ac71..047056ab2633 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. E.g.: diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index bb8806dd33f3..b9532597419a 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -60,13 +60,13 @@ function image_ref { } function build { - local base_image="$(image_ref spark-base 0)" - docker build --build-arg "spark_jars=$SPARK_JARS" \ + docker build \ + --build-arg "spark_jars=$SPARK_JARS" \ --build-arg "img_path=$IMG_PATH" \ - -t "$base_image" \ + -t spark-base \ -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . + docker build -t "$(image_ref $image)" -f ${path[$image]} . 
done } From f2dd8b923759e8771b0e5f59bfa7ae4ad7e6a339 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sat, 6 Jan 2018 16:11:20 +0800 Subject: [PATCH 095/100] [SPARK-22930][PYTHON][SQL] Improve the description of Vectorized UDFs for non-deterministic cases ## What changes were proposed in this pull request? Add tests for using non deterministic UDFs in aggregate. Update pandas_udf docstring w.r.t to determinism. ## How was this patch tested? test_nondeterministic_udf_in_aggregate Author: Li Jin Closes #20142 from icexelloss/SPARK-22930-pandas-udf-deterministic. --- python/pyspark/sql/functions.py | 12 +++++++- python/pyspark/sql/tests.py | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a4ed562ad48b..733e32bd825b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2214,7 +2214,17 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - .. note:: The user-defined function must be deterministic. + .. note:: The user-defined functions are considered deterministic by default. Due to + optimization, duplicate invocations may be eliminated or the function may even be invoked + more times than it is present in the query. If your function is not deterministic, call + `asNondeterministic` on the user defined function. E.g.: + + >>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP + ... def random(v): + ... import numpy as np + ... import pandas as pd + ... return pd.Series(np.random.randn(len(v)) + >>> random = random.asNondeterministic() # doctest: +SKIP .. note:: The user-defined functions do not support conditional expressions or short curcuiting in boolean expressions and it ends up with being executed all internally. 
If the functions diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6dc767f9ec46..689736d8e645 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -386,6 +386,7 @@ def test_udf3(self): self.assertEqual(row[0], 5) def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf import random udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() @@ -413,6 +414,18 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import udf, sum + import random + udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() + df = self.spark.range(10) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.groupby('id').agg(sum(udf_random_col())).collect() + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.agg(sum(udf_random_col())).collect() + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -3567,6 +3580,18 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() + @property + def random_udf(self): + from pyspark.sql.functions import pandas_udf + + @pandas_udf('double') + def random_udf(v): + import pandas as pd + import numpy as np + return pd.Series(np.random.random(len(v))) + random_udf = random_udf.asNondeterministic() + return random_udf + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3950,6 +3975,33 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions import udf, pandas_udf, col + + @pandas_udf('double') + def plus_ten(v): + return v + 10 + random_udf = self.random_udf + + df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) + result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() + + self.assertEqual(random_udf.deterministic, False) + self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) + + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import pandas_udf, sum + + df = self.spark.range(10) + random_udf = self.random_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.groupby(df.id).agg(sum(random_udf(df.id))).collect() + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.agg(sum(random_udf(df.id))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): From be9a804f2ef77a5044d3da7d9374976daf59fc16 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Sat, 6 Jan 2018 18:07:45 +0800 Subject: [PATCH 096/100] [SPARK-22793][SQL] Memory leak in Spark Thrift Server # What changes were proposed in this pull request? 1. Start HiveThriftServer2. 2. Connect to thriftserver through beeline. 3. Close the beeline. 4. repeat step2 and step 3 for many times. 
We found that many directories under the paths configured by `hive.exec.local.scratchdir` and `hive.exec.scratchdir` are never dropped, even though each scratchdir is added to `deleteOnExit` when it is created. This means the FileSystem `deleteOnExit` cache keeps growing until the JVM terminates. In addition, using `jmap -histo:live [PID]` to print the live objects of the HiveThriftServer2 process, we can see that the counts of `org.apache.spark.sql.hive.client.HiveClientImpl` and `org.apache.hadoop.hive.ql.session.SessionState` keep increasing even after all the beeline connections are closed, which may cause a memory leak. # How was this patch tested? Manual tests. This PR is a follow-up to https://github.com/apache/spark/pull/19989 Author: zuotingbing Closes #20029 from zuotingbing/SPARK-22793. --- .../org/apache/spark/sql/hive/HiveSessionStateBuilder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 92cb4ef11c9e..dc92ad3b0c1a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -42,7 +42,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - val client: HiveClient = externalCatalog.client.newSession() + val client: HiveClient = externalCatalog.client new HiveSessionResourceLoader(session, client) } From 7b78041423b6ee330def2336dfd1ff9ae8469c59 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 6 Jan 2018 18:19:57 +0800 Subject: [PATCH 097/100] [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. ## What changes were proposed in this pull request? Since Hive 1.1, Hive allows users to set the Parquet compression codec via the table-level property `parquet.compression`; see the JIRA: https://issues.apache.org/jira/browse/HIVE-7858 . We already support `orc.compression` for ORC, so for external users it is more straightforward to support both; see the Stack Overflow question: https://stackoverflow.com/questions/36941122/spark-sql-ignores-parquet-compression-propertie-specified-in-tblproperties On the Spark side, our table-level option `compression` was added by #11464 in Spark 2.0, so we need to support both table-level confs. Users might also use the session-level conf `spark.sql.parquet.compression.codec`. If a compression codec is configured through more than one of these, the precedence, from highest to lowest, is `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, snappy, gzip, lzo. After the change the rule for Parquet is consistent with the one for ORC. Changes: 1. Acquire `compressionCodecClassName` from `parquet.compression` as well, with the precedence order `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`, just like what we do in `OrcOptions`.
2. Change `spark.sql.parquet.compression.codec` to support "none". In `ParquetOptions` we already treat "none" as equivalent to "uncompressed", but the conf did not allow "none" to be configured. 3. Rename `compressionCode` to `compressionCodecClassName`. ## How was this patch tested? Added tests. Author: fjh100456 Closes #20076 from fjh100456/ParquetOptionIssue. --- docs/sql-programming-guide.md | 6 +- .../apache/spark/sql/internal/SQLConf.scala | 14 +- .../datasources/parquet/ParquetOptions.scala | 12 +- ...rquetCompressionCodecPrecedenceSuite.scala | 122 ++++++++++++++++++ 4 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b50f9360b866..3ccaaf4d5b1f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -953,8 +953,10 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession spark.sql.parquet.compression.codec snappy - Sets the compression codec use when writing Parquet files. Acceptable values include: - uncompressed, snappy, gzip, lzo. + Sets the compression codec used when writing Parquet files. If either `compression` or + `parquet.compression` is specified in the table-specific options/properties, the precedence would be + `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: + none, uncompressed, snappy, gzip, lzo.
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80b8965e084a..7d1217de254a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,11 +325,13 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + - "uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + + "`parquet.compression` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") @@ -366,8 +368,10 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec use when writing ORC files. Acceptable values include: " + - "none, uncompressed, snappy, zlib, lzo.") + .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + + "`orc.compress` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 772d4565de54..ef67ea7d17ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Locale +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -42,8 +43,15 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", - sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) + // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and + // `spark.sql.parquet.compression.codec` + // are in order of precedence from highest to lowest. 
+ val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) + val codecName = parameters + .get("compression") + .orElse(parquetCompressionConf) + .getOrElse(sqlConf.parquetCompressionCodec) + .toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala new file mode 100644 index 000000000000..ed8fd2b45345 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetOutputFormat + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSQLContext { + test("Test `spark.sql.parquet.compression.codec` config") { + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO").foreach { c => + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { + val expected = if (c == "NONE") "UNCOMPRESSED" else c + val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) + assert(option.compressionCodecClassName == expected) + } + } + } + + test("[SPARK-21786] Test Acquiring 'compressionCodecClassName' for parquet in right order.") { + // When "compression" is configured, it should be the first choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map("compression" -> "uncompressed", ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "UNCOMPRESSED") + } + + // When "compression" is not configured, "parquet.compression" should be the preferred choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map(ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "GZIP") + } + + // When both "compression" and "parquet.compression" are not configured, + // spark.sql.parquet.compression.codec should be the right choice. 
+ withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map.empty[String, String] + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "SNAPPY") + } + } + + private def getTableCompressionCodec(path: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + codecs.distinct + } + + private def createTableWithCompression( + tableName: String, + isPartitioned: Boolean, + compressionCodec: String, + rootDir: File): Unit = { + val options = + s""" + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName', + |'parquet.compression'='$compressionCodec') + """.stripMargin + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p)" else "" + sql( + s""" + |CREATE TABLE $tableName USING Parquet $options $partitionCreate + |AS SELECT 1 AS col1, 2 AS p + """.stripMargin) + } + + private def checkCompressionCodec(compressionCodec: String, isPartitioned: Boolean): Unit = { + withTempDir { tmpDir => + val tempTableName = "TempParquetTable" + withTable(tempTableName) { + createTableWithCompression(tempTableName, isPartitioned, compressionCodec, tmpDir) + val partitionPath = if (isPartitioned) "p=2" else "" + val path = s"${tmpDir.getPath.stripSuffix("/")}/$tempTableName/$partitionPath" + val realCompressionCodecs = getTableCompressionCodec(path) + assert(realCompressionCodecs.forall(_ == compressionCodec)) + } + } + } + + test("Create parquet table with compression") { + Seq(true, false).foreach { isPartitioned => + Seq("UNCOMPRESSED", "SNAPPY", "GZIP").foreach { compressionCodec => + checkCompressionCodec(compressionCodec, isPartitioned) + } + } + } + + test("Create table with unknown compression") { + Seq(true, false).foreach { isPartitioned => + val exception = intercept[IllegalArgumentException] { + checkCompressionCodec("aa", isPartitioned) + } + assert(exception.getMessage.contains("Codec [aa] is not available")) + } + } +} From 993f21567a1dd33e43ef9a626e0ddfbe46f83f93 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 6 Jan 2018 23:08:26 +0800 Subject: [PATCH 098/100] [SPARK-22901][PYTHON][FOLLOWUP] Adds the doc for asNondeterministic for wrapped UDF function ## What changes were proposed in this pull request? This PR wraps the `asNondeterministic` attribute in the wrapped UDF function to set the docstring properly. ```python from pyspark.sql.functions import udf help(udf(lambda x: x).asNondeterministic) ``` Before: ``` Help on function in module pyspark.sql.udf: lambda (END ``` After: ``` Help on function asNondeterministic in module pyspark.sql.udf: asNondeterministic() Updates UserDefinedFunction to nondeterministic. .. versionadded:: 2.3 (END) ``` ## How was this patch tested? Manually tested and a simple test was added. Author: hyukjinkwon Closes #20173 from HyukjinKwon/SPARK-22901-followup. 
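For reference, a minimal usage sketch (not part of this patch; it assumes a running `SparkSession` on a build that includes these changes) showing that the wrapped UDF keeps both its nondeterministic marking and a documented `asNondeterministic` attribute:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
import random

spark = SparkSession.builder.getOrCreate()

# Mark the UDF as nondeterministic so the optimizer will not collapse or
# duplicate its invocations; help() on the wrapped attribute now renders the
# proper docstring instead of the bare lambda signature.
random_udf = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()

spark.range(3).select(random_udf().alias('r')).show()
help(random_udf.asNondeterministic)
```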
--- python/pyspark/sql/tests.py | 1 + python/pyspark/sql/udf.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 689736d8e645..122a65b83aef 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -413,6 +413,7 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + pydoc.render_doc(udf(lambda x: x).asNondeterministic) def test_nondeterministic_udf_in_aggregate(self): from pyspark.sql.functions import udf, sum diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e75eb654533..5e80ab916586 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -169,8 +169,8 @@ def wrapper(*args): wrapper.returnType = self.returnType wrapper.evalType = self.evalType wrapper.deterministic = self.deterministic - wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() - + wrapper.asNondeterministic = functools.wraps( + self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped()) return wrapper def asNondeterministic(self): From 9a7048b2889bd0fd66e68a0ce3e07e466315a051 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 7 Jan 2018 00:19:21 +0800 Subject: [PATCH 099/100] [HOTFIX] Fix style checking failure ## What changes were proposed in this pull request? This PR is to fix the style checking failure. ## How was this patch tested? N/A Author: gatorsmile Closes #20175 from gatorsmile/stylefix. --- .../org/apache/spark/sql/internal/SQLConf.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7d1217de254a..5c61f10bb71a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,10 +325,11 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + - "`parquet.compression` is specified in the table-specific options/properties, the precedence" + - "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + - "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or " + + "`parquet.compression` is specified in the table-specific options/properties, the " + + "precedence would be `compression`, `parquet.compression`, " + + "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " + + "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) @@ -368,8 +369,8 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + - "`orc.compress` is specified in the table-specific options/properties, the precedence" + + .doc("Sets the compression codec used when writing ORC files. 
If either `compression` or " + + "`orc.compress` is specified in the table-specific options/properties, the precedence " + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf From 18e94149992618a2b4e6f0fd3b3f4594e1745224 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 7 Jan 2018 13:42:01 +0800 Subject: [PATCH 100/100] [SPARK-22973][SQL] Fix incorrect results of Casting Map to String ## What changes were proposed in this pull request? This pr fixed the issue when casting maps into strings; ``` scala> Seq(Map(1 -> "a", 2 -> "b")).toDF("a").write.saveAsTable("t") scala> sql("SELECT cast(a as String) FROM t").show(false) +----------------------------------------------------------------+ |a | +----------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeMapData38bdd75d| +----------------------------------------------------------------+ ``` This pr modified the result into; ``` +----------------+ |a | +----------------+ |[1 -> a, 2 -> b]| +----------------+ ``` ## How was this patch tested? Added tests in `CastSuite`. Author: Takeshi Yamamuro Closes #20166 from maropu/SPARK-22973. --- .../spark/sql/catalyst/expressions/Cast.scala | 89 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 28 ++++++ 2 files changed, 117 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d4fc5e0f168a..f2de4c8e30be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -228,6 +228,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case MapType(kt, vt, _) => + buildCast[MapData](_, map => { + val builder = new UTF8StringBuilder + builder.append("[") + if (map.numElements > 0) { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + val keyToUTF8String = castToString(kt) + val valueToUTF8String = castToString(vt) + builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(0)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < map.numElements) { + builder.append(", ") + builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(i)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(i, vt)) + .asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -654,6 +685,53 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """.stripMargin } + private def writeMapToStringBuilder( + kt: DataType, + vt: DataType, + map: String, + buffer: String, + ctx: CodegenContext): String = { + + def dataToStringFunc(func: String, dataType: DataType) = { + val funcName = ctx.freshName(func) + val dataToStringCode = castToStringCode(dataType, ctx) + ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + | UTF8String dataStr = null; + | ${dataToStringCode("data", 
"dataStr", null /* resultIsNull won't be used */)} + | return dataStr; + |} + """.stripMargin) + } + + val keyToStringFunc = dataToStringFunc("keyToString", kt) + val valueToStringFunc = dataToStringFunc("valueToString", vt) + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($map.numElements() > 0) { + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt(0)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { + | $buffer.append(", "); + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc( + | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -676,6 +754,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case MapType(kt, vt, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeMapElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e3ed7171defd..1445bb8a97d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -878,4 +878,32 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType) checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") } + + test("SPARK-22973 Cast map to string") { + val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType) + checkEvaluation(ret1, "[1 -> a, 2 -> b, 3 -> c]") + val ret2 = cast( + Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)), + StringType) + checkEvaluation(ret2, "[1 -> a, 2 ->, 3 -> c]") + val ret3 = cast( + Literal.create(Map( + 1 -> Date.valueOf("2014-12-03"), + 2 -> Date.valueOf("2014-12-04"), + 3 -> Date.valueOf("2014-12-05"))), + StringType) + checkEvaluation(ret3, "[1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05]") + val ret4 = cast( + Literal.create(Map( + 1 -> Timestamp.valueOf("2014-12-03 13:01:00"), + 2 -> Timestamp.valueOf("2014-12-04 15:05:00"))), + StringType) + checkEvaluation(ret4, "[1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00]") + val ret5 = cast( + Literal.create(Map( + 1 -> Array(1, 2, 3), + 2 -> Array(4, 5, 6))), + StringType) + checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") + } }
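A spark-shell style sketch for SPARK-22973 (not taken from the patch itself; the expected output is inferred from the new `CastSuite` cases) showing that maps with array values also get the readable form:

```scala
// Assumes a spark-shell session built with this change, where spark.implicits._ is in scope.
val df = Seq(Map(1 -> Seq(1, 2, 3), 2 -> Seq(4, 5, 6))).toDF("m")

// Before this change the cast printed the internal UnsafeMapData reference;
// with it, the row renders as the bracketed form checked in CastSuite:
// [1 -> [1, 2, 3], 2 -> [4, 5, 6]]
df.selectExpr("CAST(m AS STRING) AS m").show(false)
```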