diff --git a/R/pkg/.lintr b/R/pkg/.lintr index 39c872663ad4..038236fc149e 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE), commented_code_linter = NULL) +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 00c40c38cabc..a78fbb714f2b 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -180,7 +180,7 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), } # Save the serialization flag after we create a RRDD rdd@env$serializedMode <- serializedMode - rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD() + rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") rdd@env$jrdd_val }) @@ -225,7 +225,7 @@ setMethod("cache", #' #' Persist this RDD with the specified storage level. For details of the #' supported storage levels, refer to -#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#'\url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. #' #' @param x The RDD to persist #' @param newLevel The new storage level to be assigned @@ -382,11 +382,13 @@ setMethod("collectPartition", #' \code{collectAsMap} returns a named list as a map that contains all of the elements #' in a key-value pair RDD. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) #' collectAsMap(rdd) # list(`1` = 2, `3` = 4) #'} +# nolint end #' @rdname collect-methods #' @aliases collectAsMap,RDD-method #' @noRd @@ -442,11 +444,13 @@ setMethod("length", #' @return list of (value, count) pairs, where count is number of each unique #' value in rdd. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, c(1,2,3,2,1)) #' countByValue(rdd) # (1,2L), (2,2L), (3,1L) #'} +# nolint end #' @rdname countByValue #' @aliases countByValue,RDD-method #' @noRd @@ -597,11 +601,13 @@ setMethod("mapPartitionsWithIndex", #' @param x The RDD to be filtered. #' @param f A unary predicate function. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) #'} +# nolint end #' @rdname filterRDD #' @aliases filterRDD,RDD,function-method #' @noRd @@ -756,11 +762,13 @@ setMethod("foreachPartition", #' @param x The RDD to take elements from #' @param num Number of elements to take #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' take(rdd, 2L) # list(1, 2) #'} +# nolint end #' @rdname take #' @aliases take,RDD,numeric-method #' @noRd @@ -824,11 +832,13 @@ setMethod("first", #' @param x The RDD to remove duplicates from. #' @param numPartitions Number of partitions to create. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, c(1,2,2,3,3,3)) #' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) #'} +# nolint end #' @rdname distinct #' @aliases distinct,RDD-method #' @noRd @@ -974,11 +984,13 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", #' @param x The RDD. #' @param func The function to be applied. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3)) #' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) #'} +# nolint end #' @rdname keyBy #' @aliases keyBy,RDD #' @noRd @@ -1113,11 +1125,13 @@ setMethod("saveAsTextFile", #' @param numPartitions Number of partitions to create. #' @return An RDD where all elements are sorted. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(3, 2, 1)) #' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) #'} +# nolint end #' @rdname sortBy #' @aliases sortBy,RDD,RDD-method #' @noRd @@ -1188,11 +1202,13 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { #' @param num Number of elements to return. #' @return The first N elements from the RDD in ascending order. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) #' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) #'} +# nolint end #' @rdname takeOrdered #' @aliases takeOrdered,RDD,RDD-method #' @noRd @@ -1209,11 +1225,13 @@ setMethod("takeOrdered", #' @return The top N elements from the RDD. #' @rdname top #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) #' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) #'} +# nolint end #' @aliases top,RDD,RDD-method #' @noRd setMethod("top", @@ -1261,6 +1279,7 @@ setMethod("fold", #' @rdname aggregateRDD #' @seealso reduce #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3, 4)) @@ -1269,6 +1288,7 @@ setMethod("fold", #' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } #' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) #'} +# nolint end #' @aliases aggregateRDD,RDD,RDD-method #' @noRd setMethod("aggregateRDD", @@ -1367,12 +1387,14 @@ setMethod("setName", #' @return An RDD with zipped items. #' @seealso zipWithIndex #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) #' collect(zipWithUniqueId(rdd)) #' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) #'} +# nolint end #' @rdname zipWithUniqueId #' @aliases zipWithUniqueId,RDD #' @noRd @@ -1408,12 +1430,14 @@ setMethod("zipWithUniqueId", #' @return An RDD with zipped items. #' @seealso zipWithUniqueId #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) #' collect(zipWithIndex(rdd)) #' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) #'} +# nolint end #' @rdname zipWithIndex #' @aliases zipWithIndex,RDD #' @noRd @@ -1454,12 +1478,14 @@ setMethod("zipWithIndex", #' @return An RDD created by coalescing all elements within #' each partition into a list. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, as.list(1:4), 2L) #' collect(glom(rdd)) #' # list(list(1, 2), list(3, 4)) #'} +# nolint end #' @rdname glom #' @aliases glom,RDD #' @noRd @@ -1519,6 +1545,7 @@ setMethod("unionRDD", #' @param other Another RDD to be zipped. #' @return An RDD zipped from the two RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, 0:4) @@ -1526,6 +1553,7 @@ setMethod("unionRDD", #' collect(zipRDD(rdd1, rdd2)) #' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) #'} +# nolint end #' @rdname zipRDD #' @aliases zipRDD,RDD #' @noRd @@ -1557,12 +1585,14 @@ setMethod("zipRDD", #' @param other An RDD. #' @return A new RDD which is the Cartesian product of these two RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:2) #' sortByKey(cartesian(rdd, rdd)) #' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) #'} +# nolint end #' @rdname cartesian #' @aliases cartesian,RDD,RDD-method #' @noRd @@ -1587,6 +1617,7 @@ setMethod("cartesian", #' @param numPartitions Number of the partitions in the result RDD. #' @return An RDD with the elements from this that are not in other. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) @@ -1594,6 +1625,7 @@ setMethod("cartesian", #' collect(subtract(rdd1, rdd2)) #' # list(1, 1, 3) #'} +# nolint end #' @rdname subtract #' @aliases subtract,RDD #' @noRd @@ -1619,6 +1651,7 @@ setMethod("subtract", #' @param numPartitions The number of partitions in the result RDD. #' @return An RDD which is the intersection of these two RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) @@ -1626,6 +1659,7 @@ setMethod("subtract", #' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) #' # list(1, 2, 3) #'} +# nolint end #' @rdname intersection #' @aliases intersection,RDD #' @noRd @@ -1653,6 +1687,7 @@ setMethod("intersection", #' Assumes that all the RDDs have the *same number of partitions*, but #' does *not* require them to have the same number of elements in each partition. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 @@ -1662,6 +1697,7 @@ setMethod("intersection", #' func = function(x, y, z) { list(list(x, y, z))} )) #' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) #'} +# nolint end #' @rdname zipRDD #' @aliases zipPartitions,RDD #' @noRd diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index f7e56e43016e..d8a039327539 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -17,6 +17,7 @@ # Utility functions to deserialize objects from Java. +# nolint start # Type mapping from Java to R # # void -> NULL @@ -32,6 +33,8 @@ # # Array[T] -> list() # Object -> jobj +# +# nolint end readObject <- function(con) { # Read type first diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 334c11d2f89a..f7131140feaf 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -30,12 +30,14 @@ NULL #' @param key The key to look up for #' @return a list of values in this RDD for key key #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) #' rdd <- parallelize(sc, pairs) #' lookup(rdd, 1) # list(1, 3) #'} +# nolint end #' @rdname lookup #' @aliases lookup,RDD-method #' @noRd @@ -58,11 +60,13 @@ setMethod("lookup", #' @param x The RDD to count keys. #' @return list of (key, count) pairs, where count is number of each key in rdd. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) #' countByKey(rdd) # ("a", 2L), ("b", 1L) #'} +# nolint end #' @rdname countByKey #' @aliases countByKey,RDD-method #' @noRd @@ -77,11 +81,13 @@ setMethod("countByKey", #' #' @param x The RDD from which the keys of each tuple is returned. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) #' collect(keys(rdd)) # list(1, 3) #'} +# nolint end #' @rdname keys #' @aliases keys,RDD #' @noRd @@ -98,11 +104,13 @@ setMethod("keys", #' #' @param x The RDD from which the values of each tuple is returned. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) #' collect(values(rdd)) # list(2, 4) #'} +# nolint end #' @rdname values #' @aliases values,RDD #' @noRd @@ -348,6 +356,7 @@ setMethod("reduceByKey", #' @return A list of elements of type list(K, V') where V' is the merged value for each key #' @seealso reduceByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) @@ -355,6 +364,7 @@ setMethod("reduceByKey", #' reduced <- reduceByKeyLocally(rdd, "+") #' reduced # list(list(1, 6), list(1.1, 3)) #'} +# nolint end #' @rdname reduceByKeyLocally #' @aliases reduceByKeyLocally,RDD,integer-method #' @noRd @@ -412,6 +422,7 @@ setMethod("reduceByKeyLocally", #' @return An RDD where each element is list(K, C) where C is the combined type #' @seealso groupByKey, reduceByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) @@ -420,6 +431,7 @@ setMethod("reduceByKeyLocally", #' combined <- collect(parts) #' combined[[1]] # Should be a list(1, 6) #'} +# nolint end #' @rdname combineByKey #' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method #' @noRd @@ -473,6 +485,7 @@ setMethod("combineByKey", #' @return An RDD containing the aggregation result. #' @seealso foldByKey, combineByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -482,6 +495,7 @@ setMethod("combineByKey", #' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) #' # list(list(1, list(3, 2)), list(2, list(7, 2))) #'} +# nolint end #' @rdname aggregateByKey #' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method #' @noRd @@ -509,11 +523,13 @@ setMethod("aggregateByKey", #' @return An RDD containing the aggregation result. #' @seealso aggregateByKey, combineByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) #' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) #'} +# nolint end #' @rdname foldByKey #' @aliases foldByKey,RDD,ANY,ANY,integer-method #' @noRd @@ -540,12 +556,14 @@ setMethod("foldByKey", #' @return a new RDD containing all pairs of elements with matching keys in #' two input RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) #' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) #' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) #'} +# nolint end #' @rdname join-methods #' @aliases join,RDD,RDD-method #' @noRd @@ -578,6 +596,7 @@ setMethod("join", #' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) #' if no elements in rdd2 have key k. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) @@ -585,6 +604,7 @@ setMethod("join", #' leftOuterJoin(rdd1, rdd2, 2L) #' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) #'} +# nolint end #' @rdname join-methods #' @aliases leftOuterJoin,RDD,RDD-method #' @noRd @@ -616,6 +636,7 @@ setMethod("leftOuterJoin", #' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) #' if no elements in x have key k. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) @@ -623,6 +644,7 @@ setMethod("leftOuterJoin", #' rightOuterJoin(rdd1, rdd2, 2L) #' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) #'} +# nolint end #' @rdname join-methods #' @aliases rightOuterJoin,RDD,RDD-method #' @noRd @@ -655,6 +677,7 @@ setMethod("rightOuterJoin", #' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements #' in x/y have key k. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) @@ -664,6 +687,7 @@ setMethod("rightOuterJoin", #' # list(2, list(NULL, 4))) #' # list(3, list(3, NULL)), #'} +# nolint end #' @rdname join-methods #' @aliases fullOuterJoin,RDD,RDD-method #' @noRd @@ -688,6 +712,7 @@ setMethod("fullOuterJoin", #' @return a new RDD containing all pairs of elements with values in a list #' in all RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) @@ -695,6 +720,7 @@ setMethod("fullOuterJoin", #' cogroup(rdd1, rdd2, numPartitions = 2L) #' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) #'} +# nolint end #' @rdname cogroup #' @aliases cogroup,RDD-method #' @noRd @@ -740,11 +766,13 @@ setMethod("cogroup", #' @param numPartitions Number of partitions to create. #' @return An RDD where all (k, v) pair elements are sorted. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) #' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) #'} +# nolint end #' @rdname sortByKey #' @aliases sortByKey,RDD,RDD-method #' @noRd @@ -805,6 +833,7 @@ setMethod("sortByKey", #' @param numPartitions Number of the partitions in the result RDD. #' @return An RDD with the pairs from x whose keys are not in other. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), @@ -813,6 +842,7 @@ setMethod("sortByKey", #' collect(subtractByKey(rdd1, rdd2)) #' # list(list("b", 4), list("b", 5)) #'} +# nolint end #' @rdname subtractByKey #' @aliases subtractByKey,RDD #' @noRd diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 17082b4e52fc..095ddb9aed2e 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -17,6 +17,7 @@ # Utility functions to serialize R objects so they can be read in Java. +# nolint start # Type mapping from R to Java # # NULL -> Void @@ -31,6 +32,7 @@ # list[T] -> Array[T], where T is one of above mentioned types # environment -> Map[String, T], where T is a native type # jobj -> Object, where jobj is an object created in the backend +# nolint end getSerdeType <- function(object) { type <- class(object)[[1]] diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index 7423b4f2bed1..1b3a22486e95 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -223,14 +223,14 @@ test_that("takeSample() on RDDs", { s <- takeSample(data, TRUE, 100L, seed) expect_equal(length(s), 100L) # Chance of getting all distinct elements is astronomically low, so test we - # got < 100 + # got less than 100 expect_true(length(unique(s)) < 100L) } for (seed in 4:5) { s <- takeSample(data, TRUE, 200L, seed) expect_equal(length(s), 200L) # Chance of getting all distinct elements is still quite low, so test we - # got < 100 + # got less than 100 expect_true(length(unique(s)) < 100L) } }) diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index adf0b91d25fe..d3d0f8a24d01 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -176,8 +176,8 @@ test_that("partitionBy() partitions data correctly", { resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) - expected_first <- list(list(1, 100), list(2, 200)) # key < 3 - expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3 + expected_first <- list(list(1, 100), list(2, 200)) # key less than 3 + expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key greater than or equal 3 actual_first <- collectPartition(resultRDD, 0L) actual_second <- collectPartition(resultRDD, 1L) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 7b508b860efb..9e5d0ebf6072 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -498,9 +498,11 @@ test_that("table() returns a new DataFrame", { expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") + # nolint start # Test base::table is working #a <- letters[1:3] #expect_equal(class(table(a, sample(a))), "table") + # nolint end }) test_that("toRDD() returns an RRDD", { @@ -766,8 +768,10 @@ test_that("sample on a DataFrame", { sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) + # nolint start # Test base::sample is working #expect_equal(length(sample(1:12)), 12) + # nolint end }) test_that("select operators", { @@ -1052,8 +1056,8 @@ test_that("string operators", { df2 <- createDataFrame(sqlContext, l2) expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) - expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") - expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") + expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") # nolint + expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") # nolint l3 <- list(list(a = "a.b.c.d")) df3 <- createDataFrame(sqlContext, l3) @@ -1259,7 +1263,7 @@ test_that("filter() on a DataFrame", { expect_equal(count(filtered6), 2) # Test stats::filter is working - #expect_true(is.ts(filter(1:100, rep(1, 3)))) + #expect_true(is.ts(filter(1:100, rep(1, 3)))) # nolint }) test_that("join() and merge() on a DataFrame", { @@ -1659,7 +1663,7 @@ test_that("cov() and corr() on a DataFrame", { expect_true(abs(result - 1.0) < 1e-12) # Test stats::cov is working - #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) + #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) # nolint }) test_that("freqItems() on a DataFrame", { diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 12df4cf4f65b..56f14a3bce61 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -95,7 +95,9 @@ test_that("cleanClosure on R functions", { # TODO(shivaram): length(ls(env)) is 4 here for some reason and `lapply` is included in `env`. # Disabling this test till we debug this. # + # nolint start # expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". + # nolint end expect_true("g" %in% ls(env)) expect_true("l" %in% ls(env)) expect_true("f" %in% ls(env)) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index bbdc9158d8e2..77e44ee0264a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -874,11 +874,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, String)] = withScope { assertNotStopped() - val job = new NewHadoopJob(hadoopConfiguration) + val job = NewHadoopJob.getInstance(hadoopConfiguration) // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val updateConf = job.getConfiguration new WholeTextFileRDD( this, classOf[WholeTextFileInputFormat], @@ -923,11 +923,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = withScope { assertNotStopped() - val job = new NewHadoopJob(hadoopConfiguration) + val job = NewHadoopJob.getInstance(hadoopConfiguration) // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val updateConf = job.getConfiguration new BinaryFileRDD( this, classOf[StreamInputFormat], @@ -1100,13 +1100,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = withScope { assertNotStopped() - // The call to new NewHadoopJob automatically adds security credentials to conf, + // The call to NewHadoopJob automatically adds security credentials to conf, // so we don't need to explicitly add them ourselves - val job = new NewHadoopJob(conf) + val job = NewHadoopJob.getInstance(conf) // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updatedConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val updatedConf = job.getConfiguration new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } @@ -1369,7 +1369,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (!fs.exists(hadoopPath)) { throw new FileNotFoundException(s"Added file $hadoopPath does not exist.") } - val isDir = fs.getFileStatus(hadoopPath).isDir + val isDir = fs.getFileStatus(hadoopPath).isDirectory if (!isLocal && scheme == "file" && isDir) { throw new SparkException(s"addFile does not support local directories when not running " + "local mode.") diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index ac6eaab20d8d..dd400b8ae8a1 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -25,6 +25,7 @@ import java.util.Date import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskType import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD @@ -37,10 +38,7 @@ import org.apache.spark.util.SerializableJobConf * a filename to write to, etc, exactly like in a Hadoop MapReduce job. */ private[spark] -class SparkHadoopWriter(jobConf: JobConf) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { +class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { private val now = new Date() private val conf = new SerializableJobConf(jobConf) @@ -131,7 +129,7 @@ class SparkHadoopWriter(jobConf: JobConf) private def getJobContext(): JobContext = { if (jobContext == null) { - jobContext = newJobContext(conf.value, jID.value) + jobContext = new JobContextImpl(conf.value, jID.value) } jobContext } @@ -143,6 +141,12 @@ class SparkHadoopWriter(jobConf: JobConf) taskContext } + protected def newTaskAttemptContext( + conf: JobConf, + attemptId: TaskAttemptID): TaskAttemptContext = { + new TaskAttemptContextImpl(conf, attemptId) + } + private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { jobID = jobid splitID = splitid @@ -150,7 +154,7 @@ class SparkHadoopWriter(jobConf: JobConf) jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid)) taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) } } @@ -168,9 +172,9 @@ object SparkHadoopWriter { } val outputPath = new Path(path) val fs = outputPath.getFileSystem(conf) - if (outputPath == null || fs == null) { + if (fs == null) { throw new IllegalArgumentException("Incorrectly formatted output path") } - outputPath.makeQualified(fs) + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 59e90564b351..4bd94f13e57e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -33,9 +33,6 @@ import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.JobContext -import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} -import org.apache.hadoop.mapreduce.{TaskAttemptID => MapReduceTaskAttemptID} import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.annotation.DeveloperApi @@ -76,9 +73,6 @@ class SparkHadoopUtil extends Logging { } } - @deprecated("use newConfiguration with SparkConf argument", "1.2.0") - def newConfiguration(): Configuration = newConfiguration(null) - /** * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop * subsystems. @@ -190,33 +184,6 @@ class SparkHadoopUtil extends Logging { statisticsDataClass.getDeclaredMethod(methodName) } - /** - * Using reflection to get the Configuration from JobContext/TaskAttemptContext. If we directly - * call `JobContext/TaskAttemptContext.getConfiguration`, it will generate different byte codes - * for Hadoop 1.+ and Hadoop 2.+ because JobContext/TaskAttemptContext is class in Hadoop 1.+ - * while it's interface in Hadoop 2.+. - */ - def getConfigurationFromJobContext(context: JobContext): Configuration = { - // scalastyle:off jobconfig - val method = context.getClass.getMethod("getConfiguration") - // scalastyle:on jobconfig - method.invoke(context).asInstanceOf[Configuration] - } - - /** - * Using reflection to call `getTaskAttemptID` from TaskAttemptContext. If we directly - * call `TaskAttemptContext.getTaskAttemptID`, it will generate different byte codes - * for Hadoop 1.+ and Hadoop 2.+ because TaskAttemptContext is class in Hadoop 1.+ - * while it's interface in Hadoop 2.+. - */ - def getTaskAttemptIDFromTaskAttemptContext( - context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { - // scalastyle:off jobconfig - val method = context.getClass.getMethod("getTaskAttemptID") - // scalastyle:on jobconfig - method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] - } - /** * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the * given path points to a file, return a single-element collection containing [[FileStatus]] of @@ -233,11 +200,11 @@ class SparkHadoopUtil extends Logging { */ def listLeafStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = { def recurse(status: FileStatus): Seq[FileStatus] = { - val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDir) + val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDirectory) leaves ++ directories.flatMap(f => listLeafStatuses(fs, f)) } - if (baseStatus.isDir) recurse(baseStatus) else Seq(baseStatus) + if (baseStatus.isDirectory) recurse(baseStatus) else Seq(baseStatus) } def listLeafDirStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { @@ -246,12 +213,12 @@ class SparkHadoopUtil extends Logging { def listLeafDirStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = { def recurse(status: FileStatus): Seq[FileStatus] = { - val (directories, files) = fs.listStatus(status.getPath).partition(_.isDir) + val (directories, files) = fs.listStatus(status.getPath).partition(_.isDirectory) val leaves = if (directories.isEmpty) Seq(status) else Seq.empty[FileStatus] leaves ++ directories.flatMap(dir => listLeafDirStatuses(fs, dir)) } - assert(baseStatus.isDir) + assert(baseStatus.isDirectory) recurse(baseStatus) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 6e91d73b6e0f..c93bc8c127f5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -28,6 +28,7 @@ import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.hdfs.DistributedFileSystem +import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} @@ -167,7 +168,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } throw new IllegalArgumentException(msg) } - if (!fs.getFileStatus(path).isDir) { + if (!fs.getFileStatus(path).isDirectory) { throw new IllegalArgumentException( "Logging directory specified is not a directory: %s".format(logDir)) } @@ -304,7 +305,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logError("Exception encountered when attempting to update last scan time", e) lastScanTime } finally { - if (!fs.delete(path)) { + if (!fs.delete(path, true)) { logWarning(s"Error deleting ${path}") } } @@ -603,7 +604,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * As of Spark 1.3, these files are consolidated into a single one that replaces the directory. * See SPARK-2261 for more detail. */ - private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDir() + private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDirectory /** * Returns the modification time of the given event log. If the status points at an empty @@ -648,8 +649,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } /** - * Checks whether HDFS is in safe mode. The API is slightly different between hadoop 1 and 2, - * so we have to resort to ugly reflection (as usual...). + * Checks whether HDFS is in safe mode. * * Note that DistributedFileSystem is a `@LimitedPrivate` class, which for all practical reasons * makes it more public than not. @@ -663,11 +663,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // For testing. private[history] def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = { - val hadoop2Class = "org.apache.hadoop.hdfs.protocol.HdfsConstants$SafeModeAction" - val actionClass: Class[_] = getClass().getClassLoader().loadClass(hadoop2Class) - val action = actionClass.getField("SAFEMODE_GET").get(null) - val method = dfs.getClass().getMethod("setSafeMode", action.getClass()) - method.invoke(dfs, action).asInstanceOf[Boolean] + dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET) } } diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index 532850dd5771..30431a9b986b 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -23,7 +23,6 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.spark.Logging -import org.apache.spark.deploy.SparkHadoopUtil /** * Custom Input Format for reading and splitting flat binary files that contain records, @@ -36,7 +35,7 @@ private[spark] object FixedLengthBinaryInputFormat { /** Retrieves the record length property from a Hadoop configuration */ def getRecordLength(context: JobContext): Int = { - SparkHadoopUtil.get.getConfigurationFromJobContext(context).get(RECORD_LENGTH_PROPERTY).toInt + context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt } } diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala index 67a96925da01..25596a15d93c 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -24,7 +24,6 @@ import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.io.{BytesWritable, LongWritable} import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileSplit -import org.apache.spark.deploy.SparkHadoopUtil /** * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat. @@ -83,16 +82,16 @@ private[spark] class FixedLengthBinaryRecordReader // the actual file we will be reading from val file = fileSplit.getPath // job configuration - val job = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val conf = context.getConfiguration // check compression - val codec = new CompressionCodecFactory(job).getCodec(file) + val codec = new CompressionCodecFactory(conf).getCodec(file) if (codec != null) { throw new IOException("FixedLengthRecordReader does not support reading compressed files") } // get the record length recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) // get the filesystem - val fs = file.getFileSystem(job) + val fs = file.getFileSystem(conf) // open the File fileInputStream = fs.open(file) // seek to the splitStart position diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 280e7a5fe893..cb76e3c344fc 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -27,8 +27,6 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} -import org.apache.spark.deploy.SparkHadoopUtil - /** * A general format for reading whole files in as streams, byte arrays, * or other functions to be added @@ -44,7 +42,7 @@ private[spark] abstract class StreamFileInputFormat[T] */ def setMinPartitions(context: JobContext, minPartitions: Int) { val files = listStatus(context).asScala - val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum + val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / files.size).toLong super.setMaxSplitSize(maxSplitSize) } @@ -135,8 +133,7 @@ class PortableDataStream( private val confBytes = { val baos = new ByteArrayOutputStream() - SparkHadoopUtil.get.getConfigurationFromJobContext(context). - write(new DataOutputStream(baos)) + context.getConfiguration.write(new DataOutputStream(baos)) baos.toByteArray } diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 413408723b54..fa34f1e886c7 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -53,7 +53,7 @@ private[spark] class WholeTextFileInputFormat */ def setMinPartitions(context: JobContext, minPartitions: Int) { val files = listStatus(context).asScala - val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum + val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong super.setMaxSplitSize(maxSplitSize) diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index b56b2aa88a41..998c898a3fc2 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -26,8 +26,6 @@ import org.apache.hadoop.mapreduce.InputSplit import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader} import org.apache.hadoop.mapreduce.RecordReader import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.spark.deploy.SparkHadoopUtil - /** * A trait to implement [[org.apache.hadoop.conf.Configurable Configurable]] interface. @@ -52,8 +50,7 @@ private[spark] class WholeTextFileRecordReader( extends RecordReader[Text, Text] with Configurable { private[this] val path = split.getPath(index) - private[this] val fs = path.getFileSystem( - SparkHadoopUtil.get.getConfigurationFromJobContext(context)) + private[this] val fs = path.getFileSystem(context.getConfiguration) // True means the current file has been processed, then skip it. private[this] var processed = false diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index f7298e8d5c62..249bdf5994f8 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -18,61 +18,12 @@ package org.apache.spark.mapred import java.io.IOException -import java.lang.reflect.Modifier -import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} -import org.apache.spark.util.{Utils => SparkUtils} - -private[spark] -trait SparkHadoopMapRedUtil { - def newJobContext(conf: JobConf, jobId: JobID): JobContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", - "org.apache.hadoop.mapred.JobContext") - val ctor = klass.getDeclaredConstructor(classOf[JobConf], - classOf[org.apache.hadoop.mapreduce.JobID]) - // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private. - // Make it accessible if it's not in order to access it. - if (!Modifier.isPublic(ctor.getModifiers)) { - ctor.setAccessible(true) - } - ctor.newInstance(conf, jobId).asInstanceOf[JobContext] - } - - def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", - "org.apache.hadoop.mapred.TaskAttemptContext") - val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID]) - // See above - if (!Modifier.isPublic(ctor.getModifiers)) { - ctor.setAccessible(true) - } - ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] - } - - def newTaskAttemptID( - jtIdentifier: String, - jobId: Int, - isMap: Boolean, - taskId: Int, - attemptId: Int): TaskAttemptID = { - new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) - } - - private def firstAvailableClass(first: String, second: String): Class[_] = { - try { - SparkUtils.classForName(first) - } catch { - case e: ClassNotFoundException => - SparkUtils.classForName(second) - } - } -} object SparkHadoopMapRedUtil extends Logging { /** @@ -93,7 +44,7 @@ object SparkHadoopMapRedUtil extends Logging { jobId: Int, splitId: Int): Unit = { - val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) + val mrTaskAttemptID = mrTaskContext.getTaskAttemptID // Called after we have decided to commit def performCommit(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala deleted file mode 100644 index 82d807fad893..000000000000 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mapreduce - -import java.lang.{Boolean => JBoolean, Integer => JInteger} - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} -import org.apache.spark.util.Utils - -private[spark] -trait SparkHadoopMapReduceUtil { - def newJobContext(conf: Configuration, jobId: JobID): JobContext = { - val klass = Utils.classForName("org.apache.hadoop.mapreduce.task.JobContextImpl") - val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[JobID]) - ctor.newInstance(conf, jobId).asInstanceOf[JobContext] - } - - def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = { - val klass = Utils.classForName("org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl") - val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[TaskAttemptID]) - ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] - } - - def newTaskAttemptID( - jtIdentifier: String, - jobId: Int, - isMap: Boolean, - taskId: Int, - attemptId: Int): TaskAttemptID = { - val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID") - try { - // First, attempt to use the old-style constructor that takes a boolean isMap - // (not available in YARN) - val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean], - classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), - new JInteger(attemptId)).asInstanceOf[TaskAttemptID] - } catch { - case exc: NoSuchMethodException => { - // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) - val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType") - .asInstanceOf[Class[Enum[_]]] - val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( - taskTypeClass, if (isMap) "MAP" else "REDUCE") - val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, - classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), - new JInteger(attemptId)).asInstanceOf[TaskAttemptID] - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index aedced7408cd..2bf2337d49fe 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{ Configurable, Configuration } import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.JobContextImpl + import org.apache.spark.input.StreamFileInputFormat import org.apache.spark.{ Partition, SparkContext } @@ -40,7 +42,7 @@ private[spark] class BinaryFileRDD[T]( configurable.setConf(conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = new JobContextImpl(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index f37c95bedc0a..920d3bf219ff 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.TaskAttemptID import org.apache.hadoop.mapred.TaskID import org.apache.hadoop.mapred.lib.CombineFileSplit +import org.apache.hadoop.mapreduce.TaskType import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -357,7 +358,7 @@ private[spark] object HadoopRDD extends Logging { def addLocalConfiguration(jobTrackerId: String, jobId: Int, splitId: Int, attemptId: Int, conf: JobConf) { val jobID = new JobID(jobTrackerId, jobId) - val taId = new TaskAttemptID(new TaskID(jobID, true, splitId), attemptId) + val taId = new TaskAttemptID(new TaskID(jobID, TaskType.MAP, splitId), attemptId) conf.set("mapred.tip.id", taId.getTaskID.toString) conf.set("mapred.task.id", taId.toString) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 86f38ae836b2..8b330a34c3d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -26,11 +26,11 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} import org.apache.spark.annotation.DeveloperApi import org.apache.spark._ import org.apache.spark.executor.DataReadMethod -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} import org.apache.spark.deploy.SparkHadoopUtil @@ -66,9 +66,7 @@ class NewHadoopRDD[K, V]( keyClass: Class[K], valueClass: Class[V], @transient private val _conf: Configuration) - extends RDD[(K, V)](sc, Nil) - with SparkHadoopMapReduceUtil - with Logging { + extends RDD[(K, V)](sc, Nil) with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it private val confBroadcast = sc.broadcast(new SerializableConfiguration(_conf)) @@ -109,7 +107,7 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val jobContext = newJobContext(_conf, jobId) + val jobContext = new JobContextImpl(_conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -144,8 +142,8 @@ class NewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 44d195587a08..b87230142532 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -33,15 +33,14 @@ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, - RecordWriter => NewRecordWriter} +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskType, TaskAttemptID} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -53,10 +52,7 @@ import org.apache.spark.util.random.StratifiedSamplingUtils */ class PairRDDFunctions[K, V](self: RDD[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) - extends Logging - with SparkHadoopMapReduceUtil - with Serializable -{ + extends Logging with Serializable { /** * :: Experimental :: @@ -985,11 +981,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) conf: Configuration = self.context.hadoopConfiguration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val job = new NewAPIHadoopJob(hadoopConf) + val job = NewAPIHadoopJob.getInstance(hadoopConf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) - val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfiguration = job.getConfiguration jobConfiguration.set("mapred.output.dir", path) saveAsNewAPIHadoopDataset(jobConfiguration) } @@ -1074,11 +1070,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val job = new NewAPIHadoopJob(hadoopConf) + val job = NewAPIHadoopJob.getInstance(hadoopConf) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfiguration = job.getConfiguration val wrappedConf = new SerializableConfiguration(jobConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1091,9 +1087,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { val config = wrappedConf.value /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, + val attemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.REDUCE, context.partitionId, context.attemptNumber) - val hadoopContext = newTaskAttemptContext(config, attemptId) + val hadoopContext = new TaskAttemptContextImpl(config, attemptId) val format = outfmt.newInstance format match { case c: Configurable => c.setConf(config) @@ -1125,8 +1121,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) 1 } : Int - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.MAP, 0, 0) + val jobTaskContext = new TaskAttemptContextImpl(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) // When speculation is on and output committer class name contains "Direct", we should warn diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index fa71b8c26233..a9b3d52bbee0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -174,7 +174,8 @@ private[spark] object ReliableCheckpointRDD extends Logging { fs.create(tempOutputPath, false, bufferSize) } else { // This is mainly for testing purpose - fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + fs.create(tempOutputPath, false, bufferSize, + fs.getDefaultReplication(fs.getWorkingDirectory), blockSize) } val serializer = env.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala index e3f14fe7ef0f..8e1baae796fc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.{Text, Writable} import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.spark.{Partition, SparkContext} import org.apache.spark.input.WholeTextFileInputFormat @@ -44,7 +45,7 @@ private[spark] class WholeTextFileRDD( configurable.setConf(conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = new JobContextImpl(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index eaa07acc5132..68792c58c9b4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -77,14 +77,6 @@ private[spark] class EventLoggingListener( // Only defined if the file system scheme is not local private var hadoopDataStream: Option[FSDataOutputStream] = None - // The Hadoop APIs have changed over time, so we use reflection to figure out - // the correct method to use to flush a hadoop data stream. See SPARK-1518 - // for details. - private val hadoopFlushMethod = { - val cls = classOf[FSDataOutputStream] - scala.util.Try(cls.getMethod("hflush")).getOrElse(cls.getMethod("sync")) - } - private var writer: Option[PrintWriter] = None // For testing. Keep track of all JSON serialized events that have been logged. @@ -97,7 +89,7 @@ private[spark] class EventLoggingListener( * Creates the log file in the configured log directory. */ def start() { - if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDir) { + if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDirectory) { throw new IllegalArgumentException(s"Log directory $logBaseDir does not exist.") } @@ -147,7 +139,7 @@ private[spark] class EventLoggingListener( // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) - hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) + hadoopDataStream.foreach(_.hflush()) } if (testing) { loggedEvents += eventJson diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 0e438ab4366d..8235b1024537 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -103,7 +103,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[ org.apache.hadoop.mapreduce.InputFormat[_, _]] - val job = new Job(conf) + val job = Job.getInstance(conf) val retval = new ArrayBuffer[SplitInfo]() val list = instance.getSplits(job) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 0065b1fc660b..acc24ca0fb81 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.util import java.io.File import java.util.PriorityQueue -import scala.util.{Failure, Success, Try} +import scala.util.Try import org.apache.hadoop.fs.FileSystem import org.apache.spark.Logging @@ -177,21 +177,8 @@ private [util] class SparkShutdownHookManager { val hookTask = new Runnable() { override def run(): Unit = runAll() } - Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { - case Success(shmClass) => - val fsPriority = classOf[FileSystem] - .getField("SHUTDOWN_HOOK_PRIORITY") - .get(null) // static field, the value is not used - .asInstanceOf[Int] - val shm = shmClass.getMethod("get").invoke(null) - shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) - .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) - - case Failure(_) => - // scalastyle:off runtimeaddshutdownhook - Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); - // scalastyle:on runtimeaddshutdownhook - } + org.apache.hadoop.util.ShutdownHookManager.get().addShutdownHook( + hookTask, FileSystem.SHUTDOWN_HOOK_PRIORITY + 30) } def runAll(): Unit = { diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 11f1248c24d3..d91948e44694 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1246,7 +1246,7 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.newAPIHadoopFile(outputDir, org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, - IntWritable.class, Text.class, new Job().getConfiguration()); + IntWritable.class, Text.class, Job.getInstance().getConfiguration()); Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index f6a7f4375fac..2e47801aafd7 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark import java.io.{File, FileWriter} -import org.apache.spark.deploy.SparkHadoopUtil +import scala.io.Source + import org.apache.spark.input.PortableDataStream import org.apache.spark.storage.StorageLevel -import scala.io.Source - import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat} @@ -506,11 +505,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val randomRDD = sc.parallelize( Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - val job = new Job(sc.hadoopConfiguration) + val job = Job.getInstance(sc.hadoopConfiguration) job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) - val jobConfig = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfig = job.getConfiguration jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 5cb2d4225d28..43da6fc5b547 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -67,11 +67,11 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit val logPath = new Path(eventLogger.logPath + EventLoggingListener.IN_PROGRESS) assert(fileSystem.exists(logPath)) val logStatus = fileSystem.getFileStatus(logPath) - assert(!logStatus.isDir) + assert(!logStatus.isDirectory) // Verify log is renamed after stop() eventLogger.stop() - assert(!fileSystem.getFileStatus(new Path(eventLogger.logPath)).isDir) + assert(!fileSystem.getFileStatus(new Path(eventLogger.logPath)).isDirectory) } test("Basic event logging") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 103fc19369c9..761e82e6cf1c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -23,7 +23,6 @@ import java.net.URI import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SparkContext, SPARK_VERSION} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec @@ -115,7 +114,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { val applications = fileSystem.listStatus(logDirPath) assert(applications != null && applications.size > 0) val eventLog = applications.sortBy(_.getModificationTime).last - assert(!eventLog.isDir) + assert(!eventLog.isDirectory) // Replay events val logData = EventLoggingListener.openEventLog(eventLog.getPath(), fileSystem) diff --git a/dev/run-tests.py b/dev/run-tests.py index 706e2d141c27..23278d298c22 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -418,8 +418,9 @@ def run_python_tests(test_modules, parallelism): def run_build_tests(): - set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") - run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) + # set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") + # run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) + pass def run_sparkr_tests(): diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 4667b289f507..47cd600bd18a 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -402,7 +402,8 @@ def contains_file(self, filename): source_file_regexes=[ ".*pom.xml", "dev/test-dependencies.sh", - ] + ], + should_run_build_tests=True ) ec2 = Module( diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 984e29d1beb8..4e260e2abf04 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -22,6 +22,10 @@ set -e FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" +# Explicitly set locale in order to make `sort` output consistent across machines. +# See https://stackoverflow.com/questions/28881 for more details. +export LC_ALL=C + # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script @@ -37,7 +41,7 @@ HADOOP_PROFILES=( # resolve Spark's internal submodule dependencies. # See http://stackoverflow.com/a/3545363 for an explanation of this one-liner: -OLD_VERSION=$(mvn help:evaluate -Dexpression=project.version|grep -Ev '(^\[|Download\w+:)') +OLD_VERSION=$($MVN help:evaluate -Dexpression=project.version|grep -Ev '(^\[|Download\w+:)') TEMP_VERSION="spark-$(date +%s | tail -c6)" function reset_version { @@ -100,3 +104,5 @@ for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do exit 1 fi done + +exit 0 diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index d1b9b8d398dd..5a80985a4945 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -16,7 +16,6 @@ */ // scalastyle:off println - // scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -80,7 +79,7 @@ object CassandraCQLTest { val InputColumnFamily = "ordercf" val OutputColumnFamily = "salecount" - val job = new Job() + val job = Job.getInstance() job.setInputFormatClass(classOf[CqlPagingInputFormat]) val configuration = job.getConfiguration ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost) @@ -137,4 +136,3 @@ object CassandraCQLTest { } } // scalastyle:on println -// scalastyle:on jobcontext diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 1e679bfb5534..ad39a012b4ae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -16,7 +16,6 @@ */ // scalastyle:off println -// scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -59,7 +58,7 @@ object CassandraTest { val sc = new SparkContext(sparkConf) // Build the job configuration with ConfigHelper provided by Cassandra - val job = new Job() + val job = Job.getInstance() job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) val host: String = args(1) @@ -131,7 +130,6 @@ object CassandraTest { } } // scalastyle:on println -// scalastyle:on jobcontext /* create keyspace casDemo; diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 4b2b3f8489fd..3acc60d6c6d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -26,11 +26,9 @@ import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.MLReader -import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e0dcd427fae2..6aacffd4f236 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -24,9 +24,9 @@ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, Row} /** * (private[ml]) Trait for parameters for prediction (regression and classification). diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 3c7bcf7590e6..1f3325ad09ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -115,8 +115,8 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) - dataset.withColumn($(outputCol), - callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) + val transformUDF = udf(this.createTransformFunc, outputDataType) + dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) } override def copy(extra: ParamMap): T = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index b5258ff34847..d02806a6ea22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml.ann -import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy, - sum => Bsum} +import breeze.linalg.{*, axpy => Baxpy, sum => Bsum, DenseMatrix => BDM, DenseVector => BDV, + Vector => BV} import breeze.numerics.{log => Blog, sigmoid => Bsigmoid} import org.apache.spark.mllib.linalg.{Vector, Vectors} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index a7c10333c0d5..521d209a8f0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute import scala.annotation.varargs import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField} +import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, NumericType, StructField} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala index 7ac21d7d563f..f6964054db83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml -import org.apache.spark.sql.DataFrame import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} +import org.apache.spark.sql.DataFrame /** * ==ML attributes== diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 45df557a8990..8186afc17a53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -26,7 +26,6 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} - /** * (private[spark]) Params for classification. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index cda2bca58c50..74bf07c3f1ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index a691aa005ef5..719d1076fee8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.classification import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} -import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} -import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} +import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology} -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.DataFrame diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index fdd1851ae550..865614aa5c8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.mllib.linalg.{DenseVector, Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index d6d85ad2533a..f7d662df2fe5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 71e968497500..6e5abb29ff0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -20,15 +20,15 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} -import org.apache.spark.sql.{DataFrame, Row} /** * Common params for KMeans and KMeansModel diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 830510b1698d..af0b3e183500 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -18,19 +18,20 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors, Matrix, Vector} +import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf} import org.apache.spark.sql.types.StructType diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index c44db0ec595e..a921153b9474 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 63c06581482e..5b17d3483b89 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 324353a96afb..33abc7c99d4b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index b9e2144c0ad4..1268c87908c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 6bed72164a1d..a6f878151de7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.feature import edu.emory.mathcs.jtransforms.dct._ -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.BooleanParam import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.types.DataType /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index a359cb8f37ec..07a12df32035 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 9e15835429a3..61a78d73c434 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 2181119f04a5..7d2a1da990fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -20,13 +20,13 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.ml.Transformer -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index c2866f5eceff..559a02526591 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature - import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} @@ -25,7 +24,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.sql._ import org.apache.spark.sql.functions._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 65414ecbefbb..f8bc7e3f0c03 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index c2d514fd9629..a603b3f83320 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} import org.apache.spark.ml.util._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index d70164eaf022..c01e29af478c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 08610593fadd..42b26c8ee836 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 7bf67c6325a3..39de8461dc9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.feature import scala.collection.mutable import org.apache.spark.Logging -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntParam, _} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.util.random.XORShiftRandom /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 5c43a41bee3b..2b578c2a95e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -21,8 +21,8 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} +import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util.Identifiable diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index c09f4d076c96..e0ca45b9a619 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ -import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.types.StructType /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 318808596dc6..5d6936dce2c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 8ad7bbedaab5..8456a0e91580 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 801096fed27b..e9d1b57b91d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,13 +20,13 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 5410a50bc2e4..4813d8a5b5dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.DataFrame diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index f105a983a34f..59c34cd1703a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 4d82b90bfdf2..551e75dc0a02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.api.r +import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} -import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.DataFrame private[r] object SparkRWrappers { diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index b798aa1fab76..14a28b8d5b51 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -31,7 +31,7 @@ import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index aedfb48058dc..3787ca45d517 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -23,18 +23,18 @@ import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} import org.apache.hadoop.fs.Path +import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel -import org.apache.spark.{Logging, SparkException} /** * Params for accelerated failure time (AFT) regression. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index bbb1c7ac0a51..e8d361b1a2a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -21,18 +21,18 @@ import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter import org.apache.spark.ml.util._ -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 5e5850963edc..dee26337dcdf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -25,9 +25,9 @@ import breeze.stats.distributions.StudentsT import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.optim.WeightedLeastSquares -import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala index c72ef2968032..cf189e8e96f9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 11b9815ecc83..1bed542c4031 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 9cfd466294b9..6507a8ad7cf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -20,8 +20,8 @@ package org.apache.spark.ml.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, - Node => OldNode, Predict => OldPredict, ImpurityStats} +import org.apache.spark.mllib.tree.model.{ImpurityStats, + InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 1ee01131d633..172ea5282056 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -21,7 +21,7 @@ import java.io.IOException import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 4a3b12d1440b..6e87302c7779 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -26,10 +26,10 @@ import org.apache.spark.Logging import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator, TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index b77191156f68..40ed95773e14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tree -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * Abstraction for Decision Tree models. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 40f8857fc586..477675cad1a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -19,8 +19,8 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS import org.apache.hadoop.fs.Path -import org.json4s.jackson.JsonMethods._ import org.json4s.{DefaultFormats, JObject} +import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{Experimental, Since} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index adf06302047a..f346ea655ae5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -19,8 +19,8 @@ package org.apache.spark.ml.tuning import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 8897ab0825ac..553f25417241 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tuning import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.Estimator import org.apache.spark.ml.evaluation.Evaluator -import org.apache.spark.ml.param.{ParamMap, Param, Params} +import org.apache.spark.ml.param.{Param, ParamMap, Params} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala index bc6041b22173..6530870b83a1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.api.python -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.clustering.PowerIterationClusteringModel +import org.apache.spark.rdd.RDD /** * A Wrapper of PowerIterationClusteringModel to provide helper method for Python diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f6826ddbfabf..061db56c7493 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -42,18 +42,17 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.stat.{ + KernelDensity, MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult} -import org.apache.spark.mllib.stat.{ - KernelDensity, MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.loss.Losses import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} -import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala index 0f55980481dc..55dfd973eb25 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.api.python import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} + import scala.collection.JavaConverters._ import org.apache.spark.SparkContext diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2d52abc122bf..2a7697b5a79c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.mllib.classification.impl.GLMClassificationModel -import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} +import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} +import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 74d13e4f7794..5c9bc62cb09b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} - import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -26,11 +25,11 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} +import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{Row, SQLContext} /** * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 91fa9b0d3590..26c6235fe590 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -23,15 +23,14 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Row, SQLContext} /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 7384d065a2ea..2fce3ff64110 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum} +import breeze.linalg.{argmax, argtopk, normalize, sum, DenseMatrix => BDM, DenseVector => BDV} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 17c0609800e9..c19595e6cd21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum} -import breeze.numerics.{trigamma, abs, exp} +import breeze.linalg.{all, normalize, sum, DenseMatrix => BDM, DenseVector => BDV} +import breeze.numerics.{abs, exp, trigamma} import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index a9ba7b60bad0..647d37bd822c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum} +import breeze.linalg.{max, sum, DenseMatrix => BDM, DenseVector => BDV} import breeze.numerics._ /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index b2f140e1b135..c9a96c68667a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.clustering import scala.util.Random import org.apache.spark.Logging -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} +import org.apache.spark.mllib.linalg.Vectors /** * An utility object to run K-means locally. This is private to the ML package because it's used diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index bb1804505948..2ab0920b0636 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -17,10 +17,11 @@ package org.apache.spark.mllib.clustering -import org.json4s.JsonDSL._ import org.json4s._ +import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ @@ -30,7 +31,6 @@ import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.{Logging, SparkContext, SparkException} /** * Model produced by [[PowerIterationClustering]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 80843719f50b..79d217e183c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index 078fbfbe4f0e..f0779491e637 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD /** * Computes the area under the curve (AUC) using the trapezoidal rule. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index cc01936dd34b..f8de4e2220c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.Since -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 1d8f4fe340fb..34883f2f390d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} +import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.sql.DataFrame /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index eaa99cfe82e2..33728bf5d77e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{Row, SQLContext} /** * Chi Squared selector model. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 1f400e1430eb..a7e1b76df6a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -24,7 +24,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} - import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -36,9 +35,9 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ +import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.sql.SQLContext /** * Entry in vocabulary diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 97916daa2e9a..ed49c9492fdc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.fpm import java.{lang => jl, util => ju} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag import org.apache.spark.Logging diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala index 72d3aabc9b1f..57ca4d3464f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.impl import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 863abe86d38d..bb94745f078e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} import com.github.fommil.netlib.ARPACK -import org.netlib.util.{intW, doubleW} +import org.netlib.util.{doubleW, intW} /** * Compute eigen-decomposition. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 8879dcf75c9b..d7a74db0b1fd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg import java.util.{Arrays, Random} -import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHashSet, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, HashSet => MHashSet} import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 4dcf351df43f..cecfd067bd87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.linalg -import java.util import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} +import java.util import scala.annotation.varargs import scala.collection.JavaConverters._ @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.{compact, render, parse => parseJson} +import org.json4s.jackson.JsonMethods.{compact, parse => parseJson, render} import org.apache.spark.SparkException import org.apache.spark.annotation.{AlphaComponent, Since} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 8a70f34e70f6..97b03b340f20 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Since -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} +import org.apache.spark.rdd.RDD /** * Represents an entry in an distributed matrix. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 976299124ced..e8de515211a1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Since -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.SingularValueDecomposition +import org.apache.spark.rdd.RDD /** * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 2018a678688e..0a36da410133 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -21,8 +21,8 @@ import java.util.Arrays import scala.collection.mutable.ListBuffer -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, - svd => brzSvd, MatrixSingularException, inv} +import breeze.linalg.{axpy => brzAxpy, inv, svd => brzSvd, DenseMatrix => BDM, DenseVector => BDV, + MatrixSingularException, SparseVector => BSV} import breeze.numerics.{sqrt => brzSqrt} import org.apache.spark.Logging @@ -30,8 +30,8 @@ import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD -import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.random.XORShiftRandom /** * Represents a row-oriented distributed Matrix with no meaningful row indices. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 37bb6f6097f6..5873669b37e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -19,12 +19,12 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV, norm} +import breeze.linalg.{norm, DenseVector => BDV} -import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Vectors, Vector} /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala index 7f6d94571b5e..d8e56720967d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.optimization -import org.apache.spark.rdd.RDD - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 9f463e0cafb6..03c01e0553d7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.optimization import scala.math._ -import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV} +import breeze.linalg.{axpy => brzAxpy, norm => brzNorm, Vector => BV} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 9eab7efc160d..fa04f8eb5e79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.random import org.apache.commons.math3.distribution._ -import org.apache.spark.annotation.{Since, DeveloperApi} -import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.util.random.{Pseudorandom, XORShiftRandom} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index f8cea7ecea6b..92bc66949ae8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -17,15 +17,15 @@ package org.apache.spark.mllib.rdd +import scala.reflect.ClassTag +import scala.util.Random + import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.random.RandomDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -import scala.reflect.ClassTag -import scala.util.Random - private[mllib] class RandomRDDPartition[T](override val index: Int, val size: Int, val generator: RandomDataGenerator[T], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index ead8db634499..adb5e51947f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.rdd import scala.collection.mutable import scala.reflect.ClassTag -import org.apache.spark.{TaskContext, Partition} +import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD private[mllib] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 8f657bfb9c73..e60edc675c83 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.regression +import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.feature.StandardScaler -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.optimization._ -import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLUtils._ +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index c284ad232537..45540f0c5c4c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.regression import scala.beans.BeanInfo import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index a9aba173fa0e..d55e5dfdaaf5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 4996ace5df85..7da82c862a2b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 201333c3690d..98404be2603c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index bcb33a7a0467..f3159f7e724c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.stat import scala.annotation.varargs import org.apache.spark.annotation.Since -import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} -import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.correlation.Correlations import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult, KolmogorovSmirnovTest, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 0724af93088c..052b5b1d65b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.stat.distribution -import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV} +import breeze.linalg.{diag, eigSym, max, DenseMatrix => DBM, DenseVector => DBV, Vector => BV} import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.MLUtils /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 23c8d7c7c807..f22f2df320f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -17,16 +17,16 @@ package org.apache.spark.mllib.stat.test +import scala.collection.mutable + import breeze.linalg.{DenseMatrix => BDM} import org.apache.commons.math3.distribution.ChiSquaredDistribution -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import scala.collection.mutable - /** * Conduct the chi-squared test for the input RDDs using the specified method. * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index d2513a9d5c5b..0b118a76733f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -21,7 +21,7 @@ import scala.beans.BeanProperty import org.apache.spark.annotation.Since import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} +import org.apache.spark.mllib.tree.loss.{LogLoss, Loss, SquaredError} /** * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 372d6617a401..6c04403f1ad7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -21,9 +21,9 @@ import scala.beans.BeanProperty import scala.collection.JavaConverters._ import org.apache.spark.annotation.Since -import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} /** * Stores all the configuration options for tree construction diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 1c611976a930..fbbec1197404 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.tree.impl import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.rdd.RDD import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.storage.StorageLevel import org.apache.spark.mllib.tree.model.{Bin, Node, Split} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index ea6e5aa5d94e..66f0908c1250 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.Logging -import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.FeatureType._ /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index b85a66c05a81..783a4acb55ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index 33477ee20ebb..68835bc79677 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -19,11 +19,11 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 906bd30563bd..8af6750da4ff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -23,7 +23,7 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Since, DeveloperApi} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 4c9151f0cb4f..89186de96988 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -19,15 +19,14 @@ package org.apache.spark.mllib.util import scala.reflect.ClassTag -import org.apache.spark.annotation.Since import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PartitionwiseSampledRDD -import org.apache.spark.util.random.BernoulliCellSampler -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.dot +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.random.BernoulliCellSampler /** * Helper methods to load, save and pre-process data used in ML Lib. diff --git a/pom.xml b/pom.xml index 62ea829b1dbf..398fcc92db99 100644 --- a/pom.xml +++ b/pom.xml @@ -1951,6 +1951,11 @@ + + org.antlr + antlr3-maven-plugin + 3.5.2 + org.apache.maven.plugins diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 519052620246..9ba9f8286f10 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.5.0" + val previousSparkVersion = "1.6.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 59886ab76244..7a6e5cf4ad39 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -30,160 +30,29 @@ import com.typesafe.tools.mima.core.ProblemFilters._ * It is also possible to exclude Spark classes and packages. This should be used sparingly: * * MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") + * + * For a new Spark version, please update MimaBuild.scala to reflect the previous version. */ object MimaExcludes { def excludes(version: String) = version match { case v if v.startsWith("2.0") => Seq( - // SPARK-7995 Remove AkkaRpcEnv - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaFailure"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaFailure$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEndpointRef$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEnvFactory"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEnv"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaMessage$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEndpointRef"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.ErrorMonitor"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaMessage") + excludePackage("org.apache.spark.rpc"), + excludePackage("org.spark-project.jetty"), + excludePackage("org.apache.spark.unused"), + excludePackage("org.apache.spark.sql.catalyst"), + excludePackage("org.apache.spark.sql.execution"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this") ) ++ Seq( ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory") ) ++ - // When 1.6 is officially released, update this exclusion list. Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("network"), - MimaBuild.excludeSparkPackage("unsafe"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. - excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - // SQL execution is considered private. - excludePackage("org.apache.spark.sql.execution"), - // SQL columnar is considered private. - excludePackage("org.apache.spark.sql.columnar"), - // The shuffle package is considered private. - excludePackage("org.apache.spark.shuffle"), - // The collections utlities are considered pricate. - excludePackage("org.apache.spark.util.collection") - ) ++ - MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ - MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ - Seq( - // MiMa does not deal properly with sealed traits - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol") - ) ++ Seq( - // SPARK-11530 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this") - ) ++ Seq( - // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. - // This class is marked as `private` but MiMa still seems to be confused by the change. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") - ) ++ Seq( - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") - ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.clearLastInstantiatedContext"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.setLastInstantiatedContext"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.SQLContext$SQLSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.detachSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.tlSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.defaultSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.currentSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.openSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.setSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.createSession") - ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.preferredNodeLocationData_="), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSQLParser") - ) ++ Seq( - // SPARK-11485 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), - // SPARK-11541 mark various JDBC dialects as private - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productElement"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productArity"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.canEqual"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productIterator"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productPrefix"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.toString"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.hashCode"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.PostgresDialect$"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productElement"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productArity"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.canEqual"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productIterator"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productPrefix"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.toString"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.hashCode"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") - ) ++ Seq ( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationInfo.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.StageData.this") - ) ++ Seq( - // SPARK-11766 add toJson to Vector - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.toJson") - ) ++ Seq( - // SPARK-9065 Support message handler in Kafka Python API - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") - ) ++ Seq( - // SPARK-4557 Changed foreachRDD to use VoidFunction - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") - ) ++ Seq( - // SPARK-11996 Make the executor thread dump work again - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") - ) ++ Seq( - // SPARK-3580 Add getNumPartitions method to JavaRDD - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") - ) ++ - // SPARK-11314: YARN backend moved to yarn sub-module and MiMA complains even though it's a - // private class. - MimaBuild.excludeSparkClass("scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") + // SPARK-12481 Remove Hadoop 1.x + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.mapred.SparkHadoopMapRedUtil") + ) case v if v.startsWith("1.6") => Seq( MimaBuild.excludeSparkPackage("deploy"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c3d53f835f39..588e97f64e05 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -414,9 +414,51 @@ object Hive { // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce // in order to generate golden files. This is only required for developers who are adding new // new query tests. - fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") } - ) + fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") }, + // ANTLR code-generation step. + // + // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of + // build errors in the current plugin. + // Create Parser from ANTLR grammar files. + sourceGenerators in Compile += Def.task { + val log = streams.value.log + + val grammarFileNames = Seq( + "SparkSqlLexer.g", + "SparkSqlParser.g") + val sourceDir = (sourceDirectory in Compile).value / "antlr3" + val targetDir = (sourceManaged in Compile).value + + // Create default ANTLR Tool. + val antlr = new org.antlr.Tool + + // Setup input and output directories. + antlr.setInputDirectory(sourceDir.getPath) + antlr.setOutputDirectory(targetDir.getPath) + antlr.setForceRelativeOutput(true) + antlr.setMake(true) + + // Add grammar files. + grammarFileNames.flatMap(gFileName => (sourceDir ** gFileName).get).foreach { gFilePath => + val relGFilePath = (gFilePath relativeTo sourceDir).get.getPath + log.info("ANTLR: Grammar file '%s' detected.".format(relGFilePath)) + antlr.addGrammarFile(relGFilePath) + } + // Generate the parser. + antlr.process + if (antlr.getNumErrors > 0) { + log.error("ANTLR: Caught %d build errors.".format(antlr.getNumErrors)) + } + + // Return all generated java files. + (targetDir ** "*.java").get.toSeq + }.taskValue, + // Include ANTLR tokens files. + resourceGenerators in Compile += Def.task { + ((sourceManaged in Compile).value ** "*.tokens").get.toSeq + }.taskValue + ) } object Assembly { diff --git a/project/plugins.sbt b/project/plugins.sbt index 5e23224cf8aa..15ba3a36d51c 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -27,3 +27,5 @@ addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2") libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" + +libraryDependencies += "org.antlr" % "antlr" % "3.5.2" diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index a3d7eca04b61..a2771daabe33 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -160,6 +160,8 @@ def json(self, path, schema=None): quotes * ``allowNumericLeadingZeros`` (default ``false``): allows leading zeros in numbers \ (e.g. 00012) + * ``allowBackslashEscapingAnyCharacter`` (default ``false``): allows accepting quoting \ + of all character using backslash quoting mechanism >>> df1 = sqlContext.read.json('python/test_support/sql/people.json') >>> df1.dtypes diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 10b99175ad95..9ada96601a1c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -360,7 +360,7 @@ def test_infer_schema_to_local(self): df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) self.assertEqual(df.schema, df2.schema) - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) df3 = self.sqlCtx.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 6925e18737b7..ee855ca0e09c 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -187,14 +187,6 @@ This file is divided into 3 sections: scala.collection.JavaConverters._ and use .asScala / .asJava methods - - - ^getConfiguration$|^getTaskAttemptID$ - Instead of calling .getConfiguration() or .getTaskAttemptID() directly, - use SparkHadoopUtil's getConfigurationFromJobContext() and getTaskAttemptIDFromTaskAttemptContext() methods. - - - @@ -219,8 +211,8 @@ This file is divided into 3 sections: java,scala,3rdParty,spark - javax?\..+ - scala\..+ + javax?\..* + scala\..* (?!org\.apache\.spark\.).* org\.apache\.spark\..* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 8f4ce74a2ea3..3b775c3ca87b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -104,7 +104,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { val tableName = getTableName(tableIdent) val table = tables.get(tableName) if (table == null) { - throw new NoSuchTableException + throw new AnalysisException("Table not found: " + tableName) } val tableWithQualifiers = Subquery(tableName, table) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0b1c74293bb8..97fb3c1dd895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueri import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -45,6 +45,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { SetOperationPushDown, SamplePushDown, ReorderJoin, + OuterJoinElimination, PushPredicateThroughJoin, PushPredicateThroughProject, PushPredicateThroughGenerate, @@ -768,6 +769,79 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Elimination of outer joins, if the predicates can restrict the result sets so that + * all null-supplying rows are eliminated + * + * - full outer -> inner if both sides have such predicates + * - left outer -> inner if the right side has such predicates + * - right outer -> inner if the left side has such predicates + * - full outer -> left outer if only the left side has such predicates + * - full outer -> right outer if only the right side has such predicates + * + * This rule should be executed before pushing down the Filter + */ +object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { + + private def containsAttr(plan: LogicalPlan, attr: Attribute): Boolean = + plan.outputSet.exists(_.semanticEquals(attr)) + + private def hasNullFilteringPredicate(predicate: Expression, plan: LogicalPlan): Boolean = { + predicate match { + case EqualTo(ar: AttributeReference, _) if containsAttr(plan, ar) => true + case EqualTo(_, ar: AttributeReference) if containsAttr(plan, ar) => true + case EqualNullSafe(ar: AttributeReference, l) + if !l.nullable && containsAttr(plan, ar) => true + case EqualNullSafe(l, ar: AttributeReference) + if !l.nullable && containsAttr(plan, ar) => true + case GreaterThan(ar: AttributeReference, _) if containsAttr(plan, ar) => true + case GreaterThan(_, ar: AttributeReference) if containsAttr(plan, ar) => true + case GreaterThanOrEqual(ar: AttributeReference, _) if containsAttr(plan, ar) => true + case GreaterThanOrEqual(_, ar: AttributeReference) if containsAttr(plan, ar) => true + case LessThan(ar: AttributeReference, _) if containsAttr(plan, ar) => true + case LessThan(_, ar: AttributeReference) if containsAttr(plan, ar) => true + case LessThanOrEqual(ar: AttributeReference, _) if containsAttr(plan, ar) => true + case LessThanOrEqual(_, ar: AttributeReference) if containsAttr(plan, ar) => true + case In(ar: AttributeReference, _) if containsAttr(plan, ar) => true + case IsNotNull(ar: AttributeReference) if containsAttr(plan, ar) => true + case And(l, r) => hasNullFilteringPredicate(l, plan) || hasNullFilteringPredicate(r, plan) + case Or(l, r) => hasNullFilteringPredicate(l, plan) && hasNullFilteringPredicate(r, plan) + case Not(e) => !hasNullFilteringPredicate(e, plan) + case _ => false + } + } + + private def buildNewJoin( + otherCondition: Expression, + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]): Join = { + val leftHasNonNullPredicate = hasNullFilteringPredicate(otherCondition, left) + val rightHasNonNullPredicate = hasNullFilteringPredicate(otherCondition, right) + + joinType match { + case RightOuter if leftHasNonNullPredicate => + Join(left, right, Inner, condition) + case LeftOuter if rightHasNonNullPredicate => + Join(left, right, Inner, condition) + case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => + Join(left, right, Inner, condition) + case FullOuter if leftHasNonNullPredicate => + Join(left, right, LeftOuter, condition) + case FullOuter if rightHasNonNullPredicate => + Join(left, right, RightOuter, condition) + case _ => Join(left, right, joinType, condition) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Only three outer join types are eligible: RightOuter|LeftOuter|FullOuter + case f @ Filter(filterCond, j @ Join(left, right, RightOuter|LeftOuter|FullOuter, joinCond)) => + Filter(filterCond, buildNewJoin(filterCond, left, right, j.joinType, joinCond)) + } +} + /** * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 0ea51ece4bc5..8f4faab7bace 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -108,7 +108,8 @@ class RowEncoderSuite extends SparkFunSuite { .add("arrayOfArrayOfString", ArrayType(arrayOfString)) .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) .add("arrayOfMap", ArrayType(mapOfString)) - .add("arrayOfStruct", ArrayType(structOfString))) + .add("arrayOfStruct", ArrayType(structOfString)) + .add("arrayOfUDT", arrayOfUDT)) encodeDecodeTest( new StructType() @@ -130,18 +131,6 @@ class RowEncoderSuite extends SparkFunSuite { new StructType().add("array", arrayOfString).add("map", mapOfString)) .add("structOfUDT", structOfUDT)) - test(s"encode/decode: arrayOfUDT") { - val schema = new StructType() - .add("arrayOfUDT", arrayOfUDT) - - val encoder = RowEncoder(schema) - - val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4))) - val row = encoder.toRow(input) - val convertedBack = encoder.fromRow(row) - assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0)) - } - test(s"encode/decode: Product") { val schema = new StructType() .add("structAsProduct", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala new file mode 100644 index 000000000000..985cfdfc1541 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class OuterJoinEliminationSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Outer Join Elimination", Once, + OuterJoinElimination, + PushPredicateThroughJoin) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation('d.int, 'e.int, 'f.int) + + test("joins: full outer to inner") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) + .where("x.b".attr >= 1 && "y.d".attr >= 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('b >= 1) + val right = testRelation1.where('d >= 2) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: full outer to right") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)).where("y.d".attr > 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where('d > 2) + val correctAnswer = + left.join(right, RightOuter, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: full outer to left") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)).where("x.a".attr <=> 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a <=> 2) + val right = testRelation1 + val correctAnswer = + left.join(right, LeftOuter, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: right to inner") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, RightOuter, Option("x.a".attr === "y.d".attr)).where("x.b".attr > 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('b > 2) + val right = testRelation1 + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: left to inner") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) + .where("y.e".attr.isNotNull) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where('e.isNotNull) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: left to inner with complicated filter predicates") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) + .where(!'e.isNull || ('d.isNotNull && 'f.isNull)) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where(!'e.isNull || ('d.isNotNull && 'f.isNull)) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0acea95344c2..6debb302d9ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -258,6 +258,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * *
  • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers * (e.g. 00012)
  • + *
  • `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all + * character using backslash quoting mechanism
  • * * @since 1.6.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index eadf5cba6d9b..022303239f2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -904,8 +904,7 @@ class SQLContext private[sql]( @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = Seq( - Batch("Add exchange", Once, EnsureRequirements(self)), - Batch("Add row converters", Once, EnsureRowFormats) + Batch("Add exchange", Once, EnsureRequirements(self)) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 62cbc518e02a..7b4161930b7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -50,26 +49,14 @@ case class Exchange( case None => "" } - val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" + val simpleNodeName = "Exchange" s"$simpleNodeName$extraInfo" } - /** - * Returns true iff we can support the data type, and we are not doing range partitioning. - */ - private lazy val tungstenMode: Boolean = !newPartitioning.isInstanceOf[RangePartitioning] - override def outputPartitioning: Partitioning = newPartitioning override def output: Seq[Attribute] = child.output - // This setting is somewhat counterintuitive: - // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row, - // so the planner inserts a converter to convert data into UnsafeRow if needed. - override def outputsUnsafeRows: Boolean = tungstenMode - override def canProcessSafeRows: Boolean = !tungstenMode - override def canProcessUnsafeRows: Boolean = tungstenMode - /** * Determines whether records must be defensively copied before being sent to the shuffle. * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The @@ -130,15 +117,7 @@ case class Exchange( } } - @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf - - private val serializer: Serializer = { - if (tungstenMode) { - new UnsafeRowSerializer(child.output.size) - } else { - new SparkSqlSerializer(sparkConf) - } - } + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) override protected def doPrepare(): Unit = { // If an ExchangeCoordinator is needed, we register this Exchange operator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 5c01af011d30..fc508bfafa1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, AttributeSet, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType @@ -99,10 +99,19 @@ private[sql] case class PhysicalRDD( rdd: RDD[InternalRow], override val nodeName: String, override val metadata: Map[String, String] = Map.empty, - override val outputsUnsafeRows: Boolean = false) + isUnsafeRow: Boolean = false) extends LeafNode { - protected override def doExecute(): RDD[InternalRow] = rdd + protected override def doExecute(): RDD[InternalRow] = { + if (isUnsafeRow) { + rdd + } else { + rdd.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(schema) + iter.map(proj) + } + } + } override def simpleString: String = { val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 91530bd63798..c3683cc4e7aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -41,20 +41,11 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - private[this] val projection = { - if (outputsUnsafeRows) { - (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) - } else { - (exprs: Seq[Expression]) => newMutableProjection(exprs, child.output)() - } - } + private[this] val projection = + (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 0c613e91b979..4db88a09d815 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -64,6 +64,7 @@ case class Generate( child.execute().mapPartitionsInternal { iter => val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow + val proj = UnsafeProjection.create(output, output) iter.flatMap { row => // we should always set the left (child output) @@ -77,13 +78,14 @@ case class Generate( } ++ LazyIterator(() => boundGenerator.terminate()).map { row => // we leave the left side as the last element of its child output // keep it the same as Hive does - joinedRow.withRight(row) + proj(joinedRow.withRight(row)) } } } else { child.execute().mapPartitionsInternal { iter => - iter.flatMap(row => boundGenerator.eval(row)) ++ - LazyIterator(() => boundGenerator.terminate()) + val proj = UnsafeProjection.create(output, output) + (iter.flatMap(row => boundGenerator.eval(row)) ++ + LazyIterator(() => boundGenerator.terminate())).map(proj) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index ba7f6287ac6c..59057bf9666e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} /** @@ -29,15 +29,20 @@ private[sql] case class LocalTableScan( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { - private lazy val rdd = sqlContext.sparkContext.parallelize(rows) + private val unsafeRows: Array[InternalRow] = { + val proj = UnsafeProjection.create(output, output) + rows.map(r => proj(r).copy()).toArray + } + + private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows) protected override def doExecute(): RDD[InternalRow] = rdd override def executeCollect(): Array[InternalRow] = { - rows.toArray + unsafeRows } override def executeTake(limit: Int): Array[InternalRow] = { - rows.take(limit).toArray + unsafeRows.take(limit) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 24207cb46fd2..73dc8cb98447 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -39,10 +39,6 @@ case class Sort( testSpillFrequency: Int = 0) extends UnaryNode { - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = sortOrder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index fe9b2ad4a0bc..f20f32aaced2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -97,17 +97,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) - /** Specifies whether this operator outputs UnsafeRows */ - def outputsUnsafeRows: Boolean = false - - /** Specifies whether this operator is capable of processing UnsafeRows */ - def canProcessUnsafeRows: Boolean = false - - /** - * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows - * that are not UnsafeRows). - */ - def canProcessSafeRows: Boolean = true /** * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute @@ -115,18 +104,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Concrete implementations of SparkPlan should override doExecute instead. */ final def execute(): RDD[InternalRow] = { - if (children.nonEmpty) { - val hasUnsafeInputs = children.exists(_.outputsUnsafeRows) - val hasSafeInputs = children.exists(!_.outputsUnsafeRows) - assert(!(hasSafeInputs && hasUnsafeInputs), - "Child operators should output rows in the same format") - assert(canProcessSafeRows || canProcessUnsafeRows, - "Operator must be able to process at least one row format") - assert(!hasSafeInputs || canProcessSafeRows, - "Operator will receive safe rows as input but cannot process safe rows") - assert(!hasUnsafeInputs || canProcessUnsafeRows, - "Operator will receive unsafe rows as input but cannot process unsafe rows") - } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() doExecute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index c941d673c724..b79d93d7ca4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -100,8 +100,6 @@ case class Window( override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def canProcessUnsafeRows: Boolean = true - /** * Create a bound ordering object for a given frame type and offset. A bound ordering object is * used to determine which input row lies within the frame boundaries of an output row. @@ -259,16 +257,16 @@ case class Window( * @return the final resulting projection. */ private[this] def createResultProjection( - expressions: Seq[Expression]): MutableProjection = { + expressions: Seq[Expression]): UnsafeProjection = { val references = expressions.zipWithIndex.map{ case (e, i) => // Results of window expressions will be on the right side of child's output BoundReference(child.output.size + i, e.dataType, e.nullable) } val unboundToRefMap = expressions.zip(references).toMap val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - newMutableProjection( + UnsafeProjection.create( projectList ++ patchedWindowExpression, - child.output)() + child.output) } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index c4587ba677b2..01d076678f04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -49,10 +49,6 @@ case class SortBasedAggregate( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def requiredChildDistribution: List[Distribution] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index ac920aa8bc7f..6501634ff998 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -87,6 +87,10 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + // An SafeProjection to turn UnsafeRow into GenericInternalRow, because UnsafeRow can't be + // compared to MutableRow (aggregation buffer) directly. + private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) + protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -110,7 +114,7 @@ class SortBasedAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -122,7 +126,7 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - processRow(sortBasedAggregationBuffer, currentRow) + processRow(sortBasedAggregationBuffer, safeProj(currentRow)) } else { // We find a new group. findNextPartition = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 9d758eb3b7c3..999ebb768af5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -49,10 +49,6 @@ case class TungstenAggregate( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def producedAttributes: AttributeSet = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f19d72f06721..af7237ef2588 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -36,10 +36,6 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends override private[sql] lazy val metrics = Map( "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = projectList.map(_.toAttribute) protected override def doExecute(): RDD[InternalRow] = { @@ -80,12 +76,6 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { } override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - - override def canProcessUnsafeRows: Boolean = true - - override def canProcessSafeRows: Boolean = true } /** @@ -108,10 +98,6 @@ case class Sample( { override def output: Seq[Attribute] = child.output - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { // Disable gap sampling since the gap sampling method buffers two rows internally, @@ -135,8 +121,6 @@ case class Range( output: Seq[Attribute]) extends LeafNode { - override def outputsUnsafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = { sqlContext .sparkContext @@ -199,9 +183,6 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { } } } - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) } @@ -268,12 +249,14 @@ case class TakeOrderedAndProject( // and this ordering needs to be created on the driver in order to be passed into Spark core code. private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) - // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. - @transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) - private def collectData(): Array[InternalRow] = { val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) - projection.map(data.map(_)).getOrElse(data) + if (projectList.isDefined) { + val proj = UnsafeProjection.create(projectList.get, child.output) + data.map(r => proj(r).copy()) + } else { + data + } } override def executeCollect(): Array[InternalRow] = { @@ -311,10 +294,6 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = { child.execute().coalesce(numPartitions, shuffle = false) } - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -327,10 +306,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } - - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -343,10 +318,6 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } - - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -371,10 +342,6 @@ case class MapPartitions[T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = outputSet - override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => val tBoundEncoder = tEncoder.bind(child.output) @@ -394,11 +361,6 @@ case class AppendColumns[T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = AttributeSet(newColumns) - // We are using an unsafe combiner. - override def canProcessSafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override def output: Seq[Attribute] = child.output ++ newColumns override protected def doExecute(): RDD[InternalRow] = { @@ -428,10 +390,6 @@ case class MapGroups[K, T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = outputSet - override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -472,10 +430,6 @@ case class CoGroup[Key, Left, Right, Result]( right: SparkPlan) extends BinaryNode { override def producedAttributes: AttributeSet = outputSet - override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index aa7a668e0e93..d80912309bab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan} +import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel import org.apache.spark.{Accumulable, Accumulator, Accumulators} @@ -39,9 +39,7 @@ private[sql] object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String]): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, - if (child.outputsUnsafeRows) child else ConvertToUnsafe(child), - tableName)() + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() } /** @@ -226,8 +224,6 @@ private[sql] case class InMemoryColumnarTableScan( // The cached version does not change the outputOrdering of the original SparkPlan. override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering - override def outputsUnsafeRows: Boolean = true - private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 735d52f80886..758bcd706a8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -93,7 +93,7 @@ private[sql] case class InsertIntoHadoopFsRelation( val isAppend = pathExists && (mode == SaveMode.Append) if (doInsertion) { - val job = new Job(hadoopConf) + val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, qualifiedOutputPath) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index eea780cbaa7e..12f8783f846d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -26,10 +26,10 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql.{SQLConf, SQLContext} import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader import org.apache.spark.storage.StorageLevel @@ -68,16 +68,14 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) - extends RDD[V](sqlContext.sparkContext, Nil) - with SparkHadoopMapReduceUtil - with Logging { + extends RDD[V](sqlContext.sparkContext, Nil) with Logging { protected def getJob(): Job = { - val conf: Configuration = broadcastedConf.value.value + val conf = broadcastedConf.value.value // "new Job" will make a copy of the conf. Then, it is // safe to mutate conf properties with initLocalJobFuncOpt // and initDriverSideJobFuncOpt. - val newJob = new Job(conf) + val newJob = Job.getInstance(conf) initLocalJobFuncOpt.map(f => f(newJob)) newJob } @@ -87,7 +85,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (isDriverSide) { initDriverSideJobFuncOpt.map(f => f(job)) } - SparkHadoopUtil.get.getConfigurationFromJobContext(job) + job.getConfiguration } private val jobTrackerId: String = { @@ -110,7 +108,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = new JobContextImpl(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[SparkPartition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -154,8 +152,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) private[this] var reader: RecordReader[Void, V] = null /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 983f4df1de36..8b0b64774455 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -24,10 +24,10 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + import org.apache.spark._ -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow @@ -41,14 +41,12 @@ private[sql] abstract class BaseWriterContainer( @transient val relation: HadoopFsRelation, @transient private val job: Job, isAppend: Boolean) - extends SparkHadoopMapReduceUtil - with Logging - with Serializable { + extends Logging with Serializable { protected val dataSchema = relation.dataSchema protected val serializableConf = - new SerializableConfiguration(SparkHadoopUtil.get.getConfigurationFromJobContext(job)) + new SerializableConfiguration(job.getConfiguration) // This UUID is used to avoid output file name collision between different appending write jobs. // These jobs may belong to different SparkContext instances. Concrete data source implementations @@ -90,8 +88,7 @@ private[sql] abstract class BaseWriterContainer( // This UUID is sent to executor side together with the serialized `Configuration` object within // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate // unique task output files. - SparkHadoopUtil.get.getConfigurationFromJobContext(job). - set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, @@ -101,7 +98,7 @@ private[sql] abstract class BaseWriterContainer( // committer, since their initialization involve the job configuration, which can be potentially // decorated in `prepareJobForWrite`. outputWriterFactory = relation.prepareJobForWrite(job) - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId) outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) @@ -111,7 +108,7 @@ private[sql] abstract class BaseWriterContainer( def executorSideSetup(taskContext: TaskContext): Unit = { setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId) outputCommitter = newOutputCommitter(taskAttemptContext) outputCommitter.setupTask(taskAttemptContext) } @@ -166,7 +163,7 @@ private[sql] abstract class BaseWriterContainer( "because spark.speculation is configured to be true.") defaultOutputCommitter } else { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val committerClass = configuration.getClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) @@ -201,10 +198,8 @@ private[sql] abstract class BaseWriterContainer( private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) - this.taskId = new TaskID(this.jobId, true, splitId) - // scalastyle:off jobcontext + this.taskId = new TaskID(this.jobId, TaskType.MAP, splitId) this.taskAttemptId = new TaskAttemptID(taskId, attemptId) - // scalastyle:on jobcontext } private def setupConf(): Unit = { @@ -250,7 +245,7 @@ private[sql] class DefaultWriterContainer( def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { executorSideSetup(taskContext) - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + val configuration = taskAttemptContext.getConfiguration configuration.set("spark.sql.sources.output.path", outputPath) val writer = newOutputWriter(getWorkPath) writer.initConverter(dataSchema) @@ -421,7 +416,7 @@ private[sql] class DynamicPartitionWriterContainer( def newOutputWriter(key: InternalRow): OutputWriter = { val partitionPath = getPartitionString(key).getString(0) val path = new Path(getWorkPath, partitionPath) - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + val configuration = taskAttemptContext.getConfiguration configuration.set( "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) val newWriter = super.newOutputWriter(path.toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 7072ee4b4e3b..87d43addd36c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -179,7 +179,7 @@ private[sql] object JDBCRDD extends Logging { case stringValue: String => s"'${escapeSql(stringValue)}'" case timestampValue: Timestamp => "'" + timestampValue + "'" case dateValue: Date => "'" + dateValue + "'" - case arrayValue: Array[Object] => arrayValue.map(compileValue).mkString(", ") + case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") case _ => value } @@ -188,24 +188,44 @@ private[sql] object JDBCRDD extends Logging { /** * Turns a single Filter into a String representing a SQL expression. - * Returns null for an unhandled filter. + * Returns None for an unhandled filter. */ - private def compileFilter(f: Filter): String = f match { - case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" - case Not(f) => s"(NOT (${compileFilter(f)}))" - case LessThan(attr, value) => s"$attr < ${compileValue(value)}" - case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" - case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" - case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" - case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'" - case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'" - case StringContains(attr, value) => s"${attr} LIKE '%${value}%'" - case IsNull(attr) => s"$attr IS NULL" - case IsNotNull(attr) => s"$attr IS NOT NULL" - case In(attr, value) => s"$attr IN (${compileValue(value)})" - case Or(f1, f2) => s"(${compileFilter(f1)}) OR (${compileFilter(f2)})" - case And(f1, f2) => s"(${compileFilter(f1)}) AND (${compileFilter(f2)})" - case _ => null + private def compileFilter(f: Filter): Option[String] = { + Option(f match { + case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case EqualNullSafe(attr, value) => + s"(NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " + + s"${compileValue(value)} IS NULL) OR ($attr IS NULL AND ${compileValue(value)} IS NULL))" + case LessThan(attr, value) => s"$attr < ${compileValue(value)}" + case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" + case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case IsNull(attr) => s"$attr IS NULL" + case IsNotNull(attr) => s"$attr IS NOT NULL" + case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'" + case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'" + case StringContains(attr, value) => s"${attr} LIKE '%${value}%'" + case In(attr, value) => s"$attr IN (${compileValue(value)})" + case Not(f) => compileFilter(f).map(p => s"(NOT ($p))").getOrElse(null) + case Or(f1, f2) => + // We can't compile Or filter unless both sub-filters are compiled successfully. + // It applies too for the following And filter. + // If we can make sure compileFilter supports all filters, we can remove this check. + val or = Seq(f1, f2).map(compileFilter(_)).flatten + if (or.size == 2) { + or.map(p => s"($p)").mkString(" OR ") + } else { + null + } + case And(f1, f2) => + val and = Seq(f1, f2).map(compileFilter(_)).flatten + if (and.size == 2) { + and.map(p => s"($p)").mkString(" AND ") + } else { + null + } + case _ => null + }) } /** @@ -303,29 +323,24 @@ private[sql] class JDBCRDD( if (sb.length == 0) "1" else sb.substring(1) } - /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - private val filterWhereClause: String = { - val filterStrings = filters.map(JDBCRDD.compileFilter).filter(_ != null) - if (filterStrings.size > 0) { - val sb = new StringBuilder("WHERE ") - filterStrings.foreach(x => sb.append(x).append(" AND ")) - sb.substring(0, sb.length - 5) - } else "" - } + private val filterWhereClause: String = + filters.map(JDBCRDD.compileFilter).flatten.mkString(" AND ") /** * A WHERE clause representing both `filters`, if any, and the current partition. */ private def getWhereClause(part: JDBCPartition): String = { if (part.whereClause != null && filterWhereClause.length > 0) { - filterWhereClause + " AND " + part.whereClause + "WHERE " + filterWhereClause + " AND " + part.whereClause } else if (part.whereClause != null) { "WHERE " + part.whereClause + } else if (filterWhereClause.length > 0) { + "WHERE " + filterWhereClause } else { - filterWhereClause + "" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index c132ead20e7d..f805c0092585 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -31,7 +31,8 @@ case class JSONOptions( allowUnquotedFieldNames: Boolean = false, allowSingleQuotes: Boolean = true, allowNumericLeadingZeros: Boolean = false, - allowNonNumericNumbers: Boolean = false) { + allowNonNumericNumbers: Boolean = false, + allowBackslashEscapingAnyCharacter: Boolean = false) { /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { @@ -40,6 +41,8 @@ case class JSONOptions( factory.configure(JsonParser.Feature.ALLOW_SINGLE_QUOTES, allowSingleQuotes) factory.configure(JsonParser.Feature.ALLOW_NUMERIC_LEADING_ZEROS, allowNumericLeadingZeros) factory.configure(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers) + factory.configure(JsonParser.Feature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER, + allowBackslashEscapingAnyCharacter) } } @@ -59,6 +62,8 @@ object JSONOptions { allowNumericLeadingZeros = parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false), allowNonNumericNumbers = - parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true), + allowBackslashEscapingAnyCharacter = + parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 3e61ba35bea8..54a8552134c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -30,8 +30,6 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection @@ -89,8 +87,8 @@ private[sql] class JSONRelation( override val needConversion: Boolean = false private def createBaseRdd(inputPaths: Array[FileStatus]): RDD[String] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration val paths = inputPaths.map(_.getPath) @@ -176,7 +174,7 @@ private[json] class JsonOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with Logging { + extends OutputWriter with Logging { private[this] val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records @@ -186,9 +184,9 @@ private[json] class JsonOutputWriter( private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index a958373eb769..e5d8e6088b39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -58,9 +58,7 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with */ override def init(context: InitContext): ReadContext = { catalystRequestedSchema = { - // scalastyle:off jobcontext val conf = context.getConfiguration - // scalastyle:on jobcontext val schemaString = conf.get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA) assert(schemaString != null, "Parquet requested schema not set.") StructType.fromString(schemaString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 1a4e99ff10af..e54f51e3830f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -54,11 +54,7 @@ private[datasources] class DirectParquetOutputCommitter( override def setupTask(taskContext: TaskAttemptContext): Unit = {} override def commitJob(jobContext: JobContext) { - val configuration = { - // scalastyle:off jobcontext - ContextUtil.getConfiguration(jobContext) - // scalastyle:on jobcontext - } + val configuration = ContextUtil.getConfiguration(jobContext) val fileSystem = outputPath.getFileSystem(configuration) if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 1af2a394f399..af964b4d3561 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.metadata.CompressionCodecName @@ -40,7 +41,6 @@ import org.apache.parquet.{Log => ApacheParquetLog} import org.slf4j.bridge.SLF4JBridgeHandler import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow @@ -82,9 +82,9 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all // partitions in the case of dynamic partitioning. override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } @@ -217,11 +217,7 @@ private[sql] class ParquetRelation( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = { - // scalastyle:off jobcontext - ContextUtil.getConfiguration(job) - // scalastyle:on jobcontext - } + val conf = ContextUtil.getConfiguration(job) // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible val committerClassName = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) @@ -340,7 +336,7 @@ private[sql] class ParquetRelation( // URI of the path to create a new Path. val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) new FileStatus( - f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, + f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, f.getModificationTime, f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) }.toSeq @@ -359,7 +355,7 @@ private[sql] class ParquetRelation( } } - val jobContext = newJobContext(getConf(isDriverSide = true), jobId) + val jobContext = new JobContextImpl(getConf(isDriverSide = true), jobId) val rawSplits = inputFormat.getSplits(jobContext) Array.tabulate[SparkPartition](rawSplits.size) { i => @@ -564,7 +560,7 @@ private[sql] object ParquetRelation extends Logging { parquetFilterPushDown: Boolean, assumeBinaryIsString: Boolean, assumeInt96IsTimestamp: Boolean)(job: Job): Unit = { - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val conf = job.getConfiguration conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) // Try to push down filters when filter push-down is enabled. @@ -607,7 +603,7 @@ private[sql] object ParquetRelation extends Logging { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } - overrideMinSplitSize(parquetBlockSize, SparkHadoopUtil.get.getConfigurationFromJobContext(job)) + overrideMinSplitSize(parquetBlockSize, job.getConfiguration) } private[parquet] def readSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 41fcb11d84bf..248467abe9f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -26,8 +26,6 @@ import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -88,8 +86,8 @@ private[sql] class TextRelation( filters: Array[Filter], inputPaths: Array[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration val paths = inputPaths.map(_.getPath).sortBy(_.toUri) if (paths.nonEmpty) { @@ -138,17 +136,16 @@ private[sql] class TextRelation( } class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter - with SparkHadoopMapRedUtil { + extends OutputWriter { private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index aab177b2e842..54275c2cc113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -46,15 +46,8 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } - override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - private[this] def genResultProjection: InternalRow => InternalRow = { - if (outputsUnsafeRows) { UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } } override def outputPartitioning: Partitioning = streamed.outputPartitioning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 81bfe4e67ca7..d9fa4c6b8379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -81,10 +81,6 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output - override def canProcessSafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index fb961d97c3c3..7f9d9daa5ab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,10 +44,6 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def buildSideKeyGenerator: Projection = UnsafeProjection.create(buildKeys, buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index c6e586818751..6d464d6946b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -64,10 +64,6 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def buildKeyGenerator: Projection = UnsafeProjection.create(buildKeys, buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index f23a1830e91c..3e0f74cd98c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,10 +33,6 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def leftKeyGenerator: Projection = UnsafeProjection.create(leftKeys, left.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index efa7b49410ed..82498ee39564 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -42,9 +42,6 @@ case class LeftSemiJoinBNL( override def output: Seq[Attribute] = left.output - override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - /** The Streamed Relation */ override def left: SparkPlan = streamed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 4bf7b521c77d..812f881d06fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -53,10 +53,6 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 7ce38ebdb341..c3a2bfc59c7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -89,10 +89,6 @@ case class SortMergeOuterJoin( keys.map(SortOrder(_, Ascending)) } - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - private def createLeftKeyGenerator(): Projection = UnsafeProjection.create(leftKeys, left.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 6a882c9234df..e46217050bad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -69,18 +69,6 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin */ def close(): Unit - /** Specifies whether this operator outputs UnsafeRows */ - def outputsUnsafeRows: Boolean = false - - /** Specifies whether this operator is capable of processing UnsafeRows */ - def canProcessUnsafeRows: Boolean = false - - /** - * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows - * that are not UnsafeRows). - */ - def canProcessSafeRows: Boolean = true - /** * Returns the content through the [[Iterator]] interface. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala index 7321fc66b4dd..b7fa0c020222 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -47,11 +47,7 @@ case class NestedLoopJoinNode( } private[this] def genResultProjection: InternalRow => InternalRow = { - if (outputsUnsafeRows) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } + UnsafeProjection.create(schema) } private[this] var currentRow: InternalRow = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index defcec95fb55..efb4b09c1634 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -351,10 +351,6 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil - override def outputsUnsafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute().map(_.copy()) val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) @@ -400,13 +396,14 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val unpickle = new Unpickler val row = new GenericMutableRow(1) val joined = new JoinedRow + val resultProj = UnsafeProjection.create(output, output) outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => row(0) = EvaluatePython.fromJava(result, udf.dataType) - joined(queue.poll(), row) + resultProj(joined(queue.poll(), row)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala deleted file mode 100644 index 5f8fc2de8b46..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.rules.Rule - -/** - * Converts Java-object-based rows into [[UnsafeRow]]s. - */ -case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { - - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val convertToUnsafe = UnsafeProjection.create(child.schema) - iter.map(convertToUnsafe) - } - } -} - -/** - * Converts [[UnsafeRow]]s back into Java-object-based rows. - */ -case class ConvertToSafe(child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputsUnsafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType)) - iter.map(convertToSafe) - } - } -} - -private[sql] object EnsureRowFormats extends Rule[SparkPlan] { - - private def onlyHandlesSafeRows(operator: SparkPlan): Boolean = - operator.canProcessSafeRows && !operator.canProcessUnsafeRows - - private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean = - operator.canProcessUnsafeRows && !operator.canProcessSafeRows - - private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean = - operator.canProcessSafeRows && operator.canProcessUnsafeRows - - override def apply(operator: SparkPlan): SparkPlan = operator.transformUp { - case operator: SparkPlan if onlyHandlesSafeRows(operator) => - if (operator.children.exists(_.outputsUnsafeRows)) { - operator.withNewChildren { - operator.children.map { - c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c - } - } - } else { - operator - } - case operator: SparkPlan if onlyHandlesUnsafeRows(operator) => - if (operator.children.exists(!_.outputsUnsafeRows)) { - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator - } - case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => - if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { - // If this operator's children produce both unsafe and safe rows, - // convert everything unsafe rows. - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 97c5aed6da9c..3572f3c3a1f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2843,6 +2843,20 @@ object functions extends LegacyFunctions { // scalastyle:on parameter.number // scalastyle:on line.size.limit + /** + * Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must + * specifcy the output data type, and there is no automatic input type coercion. + * + * @param f A closure in Scala + * @param dataType The output data type of the UDF + * + * @group udf_funcs + * @since 2.0.0 + */ + def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { + UserDefinedFunction(f, dataType, None) + } + /** * Call an user-defined function. * Example: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index fc8ce6901dfc..d6c5d1435702 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -462,7 +462,7 @@ abstract class HadoopFsRelation private[sql]( name.toLowerCase == "_temporary" || name.startsWith(".") } - val (dirs, files) = statuses.partition(_.isDir) + val (dirs, files) = statuses.partition(_.isDirectory) // It uses [[LinkedHashSet]] since the order of files can affect the results. (SPARK-11500) if (dirs.isEmpty) { @@ -858,10 +858,10 @@ private[sql] object HadoopFsRelation extends Logging { val jobConf = new JobConf(fs.getConf, this.getClass()) val pathFilter = FileInputFormat.getInputPathFilter(jobConf) if (pathFilter != null) { - val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDirectory) files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDirectory) files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) } } @@ -896,7 +896,7 @@ private[sql] object HadoopFsRelation extends Logging { FakeFileStatus( status.getPath.toString, status.getLen, - status.isDir, + status.isDirectory, status.getReplication, status.getBlockSize, status.getModificationTime, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index d86df4cfb9b4..6b735bcf1610 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException + import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.PhysicalRDD @@ -289,7 +289,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext testData.select('key).registerTempTable("t1") sqlContext.table("t1") sqlContext.dropTempTable("t1") - intercept[NoSuchTableException](sqlContext.table("t1")) + intercept[AnalysisException](sqlContext.table("t1")) } test("Drops cached temporary table") { @@ -301,7 +301,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(sqlContext.isCached("t2")) sqlContext.dropTempTable("t1") - intercept[NoSuchTableException](sqlContext.table("t1")) + intercept[AnalysisException](sqlContext.table("t1")) assert(!sqlContext.isCached("t2")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 39a65413bd59..0cbddf4d37d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -140,4 +142,50 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { assert(df1.join(broadcast(pf1)).count() === 4) } } + + test("join - outer join conversion") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") + + // outer -> left + val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", "outer").where($"a.int" === 3) + assert(outerJoin2Left.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, LeftOuter, _) => j }.size === 1) + checkAnswer( + outerJoin2Left, + Row(3, 4, "3", null, null, null) :: Nil) + + // outer -> right + val outerJoin2Right = df.join(df2, $"a.int" === $"b.int", "outer").where($"b.int" === 5) + assert(outerJoin2Right.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, RightOuter, _) => j }.size === 1) + checkAnswer( + outerJoin2Right, + Row(null, null, null, 5, 6, "5") :: Nil) + + // outer -> inner + val outerJoin2Inner = df.join(df2, $"a.int" === $"b.int", "outer"). + where($"a.int" === 1 && $"b.int2" === 3) + assert(outerJoin2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Inner, _) => j }.size === 1) + checkAnswer( + outerJoin2Inner, + Row(1, 2, "1", 1, 3, "1") :: Nil) + + // right -> inner + val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", "right").where($"a.int" === 1) + assert(rightJoin2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Inner, _) => j }.size === 1) + checkAnswer( + rightJoin2Inner, + Row(1, 2, "1", 1, 3, "1") :: Nil) + + // left -> inner + val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" === 3) + assert(leftJoin2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Inner, _) => j }.size === 1) + checkAnswer( + leftJoin2Inner, + Row(1, 2, "1", 1, 3, "1") :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 9a3c262e9485..935dade233a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -82,7 +82,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[SortMergeOuterJoin]), + classOf[SortMergeJoin]), // conversion from Right Outer to Inner ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", @@ -123,8 +123,12 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), + classOf[BroadcastHashJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where a = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and a = 2", classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } sql("UNCACHE TABLE testData") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 911d12e93e50..87bff3295f5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -28,7 +28,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))), + plan => Exchange(SinglePartition, plan), input.map(Row.fromTuple) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala deleted file mode 100644 index faef76d52ae7..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal} -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.IntegerType - -class ExpandSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder - - private def testExpand(f: SparkPlan => SparkPlan): Unit = { - val input = (1 to 1000).map(Tuple1.apply) - val projections = Seq.tabulate(2) { i => - Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil - } - val attributes = projections.head.map(_.toAttribute) - checkAnswer( - input.toDF(), - plan => Expand(projections, attributes, f(plan)), - input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j))) - ) - } - - test("inheriting child row type") { - val exprs = AttributeReference("a", IntegerType, false)() :: Nil - val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty))) - assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.") - } - - test("expanding UnsafeRows") { - testExpand(ConvertToUnsafe) - } - - test("expanding SafeRows") { - testExpand(identity) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala deleted file mode 100644 index 2328899bb2f8..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} -import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{ArrayType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { - - private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { - case c: ConvertToUnsafe => c - case c: ConvertToSafe => c - } - - private val outputsSafe = ReferenceSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) - assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) - assert(outputsUnsafe.outputsUnsafeRows) - - test("planner should insert unsafe->safe conversions when required") { - val plan = Limit(10, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) - } - - test("filter can process unsafe rows") { - val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 1) - assert(preparedPlan.outputsUnsafeRows) - } - - test("filter can process safe rows") { - val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).isEmpty) - assert(!preparedPlan.outputsUnsafeRows) - } - - test("coalesce can process unsafe rows") { - val plan = Coalesce(1, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 1) - assert(preparedPlan.outputsUnsafeRows) - } - - test("except can process unsafe rows") { - val plan = Except(outputsUnsafe, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 2) - assert(preparedPlan.outputsUnsafeRows) - } - - test("except requires all of its input rows' formats to agree") { - val plan = Except(outputsSafe, outputsUnsafe) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("intersect can process unsafe rows") { - val plan = Intersect(outputsUnsafe, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 2) - assert(preparedPlan.outputsUnsafeRows) - } - - test("intersect requires all of its input rows' formats to agree") { - val plan = Intersect(outputsSafe, outputsUnsafe) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("execute() fails an assertion if inputs rows are of different formats") { - val e = intercept[AssertionError] { - Union(Seq(outputsSafe, outputsUnsafe)).execute() - } - assert(e.getMessage.contains("format")) - } - - test("union requires all of its input rows' formats to agree") { - val plan = Union(Seq(outputsSafe, outputsUnsafe)) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("union can process safe rows") { - val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(!preparedPlan.outputsUnsafeRows) - } - - test("union can process unsafe rows") { - val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("round trip with ConvertToUnsafe and ConvertToSafe") { - val input = Seq(("hello", 1), ("world", 2)) - checkAnswer( - sqlContext.createDataFrame(input), - plan => ConvertToSafe(ConvertToUnsafe(plan)), - input.map(Row.fromTuple) - ) - } - - test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SQLContext.setActive(sqlContext) - val schema = ArrayType(StringType) - val rows = (1 to 100).map { i => - InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) - } - val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) - - val plan = - DummyPlan( - ConvertToSafe( - ConvertToUnsafe(relation))) - assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) - } -} - -case class DummyPlan(child: SparkPlan) extends UnaryNode { - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some - // values gotten from the incoming rows. - // we cache all strings here to make sure we have deep copied UTF8String inside incoming - // safe InternalRow. - val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] - iter.foreach { row => - strings += row.getArray(0).getUTF8String(0) - } - strings.map(InternalRow(_)).iterator - } - } - - override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index e5d34be4c65e..af971dfc6fae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -99,7 +99,7 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { ) checkThatPlansAgree( inputDf, - p => ConvertToSafe(Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23)), + p => Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23), ReferenceSort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index 4cc0a3a9585d..1742df31bba9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -111,4 +111,23 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { assert(df.schema.head.name == "age") assert(df.first().getDouble(0).isNaN) } + + test("allowBackslashEscapingAnyCharacter off") { + val str = """{"name": "Cazen Lee", "price": "\$10"}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowBackslashEscapingAnyCharacter on") { + val str = """{"name": "Cazen Lee", "price": "\$10"}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.schema.last.name == "price") + assert(df.first().getString(0) == "Cazen Lee") + assert(df.first().getString(1) == "$10") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 00e37f107a88..dae72e8acb5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -185,12 +185,13 @@ class JDBCSuite extends SparkFunSuite assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME <=> 'fred'")).collect().size == 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME IN ('mary', 'fred')")) .collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME NOT IN ('fred')")) - .collect().size === 2) + .collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary'")) .collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary' " @@ -453,8 +454,8 @@ class JDBCSuite extends SparkFunSuite } test("compile filters") { - val compileFilter = PrivateMethod[String]('compileFilter) - def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) + val compileFilter = PrivateMethod[Option[String]]('compileFilter) + def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) getOrElse("") assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3") assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "(NOT (col1 = 'abc'))") assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def"))) @@ -473,6 +474,9 @@ class JDBCSuite extends SparkFunSuite === "(NOT (col1 IN ('mno', 'pqr')))") assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL") assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL") + assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def"))) + === "((NOT (col0 != 'abc' OR col0 IS NULL OR 'abc' IS NULL) " + + "OR (col0 IS NULL AND 'abc' IS NULL))) AND (col1 = 'def')") } test("Dialect unregister") { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 2d0d7b8af358..2b0e48dbfcf2 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -308,7 +308,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // The difference between the double numbers generated by Hive and Spark // can be ignored (e.g., 0.6633880657639323 and 0.6633880657639322) - "udaf_corr" + "udaf_corr", + + // Feature removed in HIVE-11145 + "alter_partition_protect_mode", + "drop_partitions_ignore_protection", + "protectmode" ) /** @@ -328,7 +333,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_index", "alter_merge_2", "alter_partition_format_loc", - "alter_partition_protect_mode", "alter_partition_with_whitelist", "alter_rename_partition", "alter_table_serde", @@ -460,7 +464,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "drop_partitions_filter", "drop_partitions_filter2", "drop_partitions_filter3", - "drop_partitions_ignore_protection", "drop_table", "drop_table2", "drop_table_removes_partition_dirs", @@ -778,7 +781,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ppr_pushdown2", "ppr_pushdown3", "progress_1", - "protectmode", "push_or", "query_with_semi", "quote1", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index e9885f668202..ffabb92179a1 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -232,6 +232,7 @@ v${hive.version.short}/src/main/scala + ${project.build.directory/generated-sources/antlr @@ -260,6 +261,27 @@
    + + + + org.antlr + antlr3-maven-plugin + + + + antlr + + + + + ${basedir}/src/main/antlr3 + + **/SparkSqlLexer.g + **/SparkSqlParser.g + + + + diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g new file mode 100644 index 000000000000..e4a80f0ce8eb --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g @@ -0,0 +1,330 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar FromClauseParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.errors.add(new ParseError(gParent, e, tokenNames)); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------------------------------------------------------------------- + +tableAllColumns + : STAR + -> ^(TOK_ALLCOLREF) + | tableName DOT STAR + -> ^(TOK_ALLCOLREF tableName) + ; + +// (table|column) +tableOrColumn +@init { gParent.pushMsg("table or column identifier", state); } +@after { gParent.popMsg(state); } + : + identifier -> ^(TOK_TABLE_OR_COL identifier) + ; + +expressionList +@init { gParent.pushMsg("expression list", state); } +@after { gParent.popMsg(state); } + : + expression (COMMA expression)* -> ^(TOK_EXPLIST expression+) + ; + +aliasList +@init { gParent.pushMsg("alias list", state); } +@after { gParent.popMsg(state); } + : + identifier (COMMA identifier)* -> ^(TOK_ALIASLIST identifier+) + ; + +//----------------------- Rules for parsing fromClause ------------------------------ +// from [col1, col2, col3] table1, [col4, col5] table2 +fromClause +@init { gParent.pushMsg("from clause", state); } +@after { gParent.popMsg(state); } + : + KW_FROM joinSource -> ^(TOK_FROM joinSource) + ; + +joinSource +@init { gParent.pushMsg("join source", state); } +@after { gParent.popMsg(state); } + : fromSource ( joinToken^ fromSource ( KW_ON! expression {$joinToken.start.getType() != COMMA}? )? )* + | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+ + ; + +uniqueJoinSource +@init { gParent.pushMsg("unique join source", state); } +@after { gParent.popMsg(state); } + : KW_PRESERVE? fromSource uniqueJoinExpr + ; + +uniqueJoinExpr +@init { gParent.pushMsg("unique join expression list", state); } +@after { gParent.popMsg(state); } + : LPAREN e1+=expression (COMMA e1+=expression)* RPAREN + -> ^(TOK_EXPLIST $e1*) + ; + +uniqueJoinToken +@init { gParent.pushMsg("unique join", state); } +@after { gParent.popMsg(state); } + : KW_UNIQUEJOIN -> TOK_UNIQUEJOIN; + +joinToken +@init { gParent.pushMsg("join type specifier", state); } +@after { gParent.popMsg(state); } + : + KW_JOIN -> TOK_JOIN + | KW_INNER KW_JOIN -> TOK_JOIN + | COMMA -> TOK_JOIN + | KW_CROSS KW_JOIN -> TOK_CROSSJOIN + | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN + | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN + | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN + | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN + | KW_ANTI KW_JOIN -> TOK_ANTIJOIN + ; + +lateralView +@init {gParent.pushMsg("lateral view", state); } +@after {gParent.popMsg(state); } + : + (KW_LATERAL KW_VIEW KW_OUTER) => KW_LATERAL KW_VIEW KW_OUTER function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? + -> ^(TOK_LATERAL_VIEW_OUTER ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) + | + KW_LATERAL KW_VIEW function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? + -> ^(TOK_LATERAL_VIEW ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) + ; + +tableAlias +@init {gParent.pushMsg("table alias", state); } +@after {gParent.popMsg(state); } + : + identifier -> ^(TOK_TABALIAS identifier) + ; + +fromSource +@init { gParent.pushMsg("from source", state); } +@after { gParent.popMsg(state); } + : + (LPAREN KW_VALUES) => fromSource0 + | (LPAREN) => LPAREN joinSource RPAREN -> joinSource + | fromSource0 + ; + + +fromSource0 +@init { gParent.pushMsg("from source 0", state); } +@after { gParent.popMsg(state); } + : + ((Identifier LPAREN)=> partitionedTableFunction | tableSource | subQuerySource | virtualTableSource) (lateralView^)* + ; + +tableBucketSample +@init { gParent.pushMsg("table bucket sample specification", state); } +@after { gParent.popMsg(state); } + : + KW_TABLESAMPLE LPAREN KW_BUCKET (numerator=Number) KW_OUT KW_OF (denominator=Number) (KW_ON expr+=expression (COMMA expr+=expression)*)? RPAREN -> ^(TOK_TABLEBUCKETSAMPLE $numerator $denominator $expr*) + ; + +splitSample +@init { gParent.pushMsg("table split sample specification", state); } +@after { gParent.popMsg(state); } + : + KW_TABLESAMPLE LPAREN (numerator=Number) (percent=KW_PERCENT|KW_ROWS) RPAREN + -> {percent != null}? ^(TOK_TABLESPLITSAMPLE TOK_PERCENT $numerator) + -> ^(TOK_TABLESPLITSAMPLE TOK_ROWCOUNT $numerator) + | + KW_TABLESAMPLE LPAREN (numerator=ByteLengthLiteral) RPAREN + -> ^(TOK_TABLESPLITSAMPLE TOK_LENGTH $numerator) + ; + +tableSample +@init { gParent.pushMsg("table sample specification", state); } +@after { gParent.popMsg(state); } + : + tableBucketSample | + splitSample + ; + +tableSource +@init { gParent.pushMsg("table source", state); } +@after { gParent.popMsg(state); } + : tabname=tableName + ((tableProperties) => props=tableProperties)? + ((tableSample) => ts=tableSample)? + ((KW_AS) => (KW_AS alias=Identifier) + | + (Identifier) => (alias=Identifier))? + -> ^(TOK_TABREF $tabname $props? $ts? $alias?) + ; + +tableName +@init { gParent.pushMsg("table name", state); } +@after { gParent.popMsg(state); } + : + db=identifier DOT tab=identifier + -> ^(TOK_TABNAME $db $tab) + | + tab=identifier + -> ^(TOK_TABNAME $tab) + ; + +viewName +@init { gParent.pushMsg("view name", state); } +@after { gParent.popMsg(state); } + : + (db=identifier DOT)? view=identifier + -> ^(TOK_TABNAME $db? $view) + ; + +subQuerySource +@init { gParent.pushMsg("subquery source", state); } +@after { gParent.popMsg(state); } + : + LPAREN queryStatementExpression[false] RPAREN KW_AS? identifier -> ^(TOK_SUBQUERY queryStatementExpression identifier) + ; + +//---------------------- Rules for parsing PTF clauses ----------------------------- +partitioningSpec +@init { gParent.pushMsg("partitioningSpec clause", state); } +@after { gParent.popMsg(state); } + : + partitionByClause orderByClause? -> ^(TOK_PARTITIONINGSPEC partitionByClause orderByClause?) | + orderByClause -> ^(TOK_PARTITIONINGSPEC orderByClause) | + distributeByClause sortByClause? -> ^(TOK_PARTITIONINGSPEC distributeByClause sortByClause?) | + sortByClause -> ^(TOK_PARTITIONINGSPEC sortByClause) | + clusterByClause -> ^(TOK_PARTITIONINGSPEC clusterByClause) + ; + +partitionTableFunctionSource +@init { gParent.pushMsg("partitionTableFunctionSource clause", state); } +@after { gParent.popMsg(state); } + : + subQuerySource | + tableSource | + partitionedTableFunction + ; + +partitionedTableFunction +@init { gParent.pushMsg("ptf clause", state); } +@after { gParent.popMsg(state); } + : + name=Identifier LPAREN KW_ON + ((partitionTableFunctionSource) => (ptfsrc=partitionTableFunctionSource spec=partitioningSpec?)) + ((Identifier LPAREN expression RPAREN ) => Identifier LPAREN expression RPAREN ( COMMA Identifier LPAREN expression RPAREN)*)? + ((RPAREN) => (RPAREN)) ((Identifier) => alias=Identifier)? + -> ^(TOK_PTBLFUNCTION $name $alias? $ptfsrc $spec? expression*) + ; + +//----------------------- Rules for parsing whereClause ----------------------------- +// where a=b and ... +whereClause +@init { gParent.pushMsg("where clause", state); } +@after { gParent.popMsg(state); } + : + KW_WHERE searchCondition -> ^(TOK_WHERE searchCondition) + ; + +searchCondition +@init { gParent.pushMsg("search condition", state); } +@after { gParent.popMsg(state); } + : + expression + ; + +//----------------------------------------------------------------------------------- + +//-------- Row Constructor ---------------------------------------------------------- +//in support of SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as FOO(a,b,c) and +// INSERT INTO (col1,col2,...) VALUES(...),(...),... +// INSERT INTO
    (col1,col2,...) SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as Foo(a,b,c) +valueRowConstructor +@init { gParent.pushMsg("value row constructor", state); } +@after { gParent.popMsg(state); } + : + LPAREN precedenceUnaryPrefixExpression (COMMA precedenceUnaryPrefixExpression)* RPAREN -> ^(TOK_VALUE_ROW precedenceUnaryPrefixExpression+) + ; + +valuesTableConstructor +@init { gParent.pushMsg("values table constructor", state); } +@after { gParent.popMsg(state); } + : + valueRowConstructor (COMMA valueRowConstructor)* -> ^(TOK_VALUES_TABLE valueRowConstructor+) + ; + +/* +VALUES(1),(2) means 2 rows, 1 column each. +VALUES(1,2),(3,4) means 2 rows, 2 columns each. +VALUES(1,2,3) means 1 row, 3 columns +*/ +valuesClause +@init { gParent.pushMsg("values clause", state); } +@after { gParent.popMsg(state); } + : + KW_VALUES valuesTableConstructor -> valuesTableConstructor + ; + +/* +This represents a clause like this: +(VALUES(1,2),(2,3)) as VirtTable(col1,col2) +*/ +virtualTableSource +@init { gParent.pushMsg("virtual table source", state); } +@after { gParent.popMsg(state); } + : + LPAREN valuesClause RPAREN tableNameColList -> ^(TOK_VIRTUAL_TABLE tableNameColList valuesClause) + ; +/* +e.g. as VirtTable(col1,col2) +Note that we only want literals as column names +*/ +tableNameColList +@init { gParent.pushMsg("from source", state); } +@after { gParent.popMsg(state); } + : + KW_AS? identifier LPAREN identifier (COMMA identifier)* RPAREN -> ^(TOK_VIRTUAL_TABREF ^(TOK_TABNAME identifier) ^(TOK_COL_NAME identifier+)) + ; + +//----------------------------------------------------------------------------------- diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g new file mode 100644 index 000000000000..5c3d7ef86624 --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g @@ -0,0 +1,697 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar IdentifiersParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.errors.add(new ParseError(gParent, e, tokenNames)); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------------------------------------------------------------------- + +// group by a,b +groupByClause +@init { gParent.pushMsg("group by clause", state); } +@after { gParent.popMsg(state); } + : + KW_GROUP KW_BY + expression + ( COMMA expression)* + ((rollup=KW_WITH KW_ROLLUP) | (cube=KW_WITH KW_CUBE)) ? + (sets=KW_GROUPING KW_SETS + LPAREN groupingSetExpression ( COMMA groupingSetExpression)* RPAREN ) ? + -> {rollup != null}? ^(TOK_ROLLUP_GROUPBY expression+) + -> {cube != null}? ^(TOK_CUBE_GROUPBY expression+) + -> {sets != null}? ^(TOK_GROUPING_SETS expression+ groupingSetExpression+) + -> ^(TOK_GROUPBY expression+) + ; + +groupingSetExpression +@init {gParent.pushMsg("grouping set expression", state); } +@after {gParent.popMsg(state); } + : + (LPAREN) => groupingSetExpressionMultiple + | + groupingExpressionSingle + ; + +groupingSetExpressionMultiple +@init {gParent.pushMsg("grouping set part expression", state); } +@after {gParent.popMsg(state); } + : + LPAREN + expression? (COMMA expression)* + RPAREN + -> ^(TOK_GROUPING_SETS_EXPRESSION expression*) + ; + +groupingExpressionSingle +@init { gParent.pushMsg("groupingExpression expression", state); } +@after { gParent.popMsg(state); } + : + expression -> ^(TOK_GROUPING_SETS_EXPRESSION expression) + ; + +havingClause +@init { gParent.pushMsg("having clause", state); } +@after { gParent.popMsg(state); } + : + KW_HAVING havingCondition -> ^(TOK_HAVING havingCondition) + ; + +havingCondition +@init { gParent.pushMsg("having condition", state); } +@after { gParent.popMsg(state); } + : + expression + ; + +expressionsInParenthese + : + LPAREN expression (COMMA expression)* RPAREN -> expression+ + ; + +expressionsNotInParenthese + : + expression (COMMA expression)* -> expression+ + ; + +columnRefOrderInParenthese + : + LPAREN columnRefOrder (COMMA columnRefOrder)* RPAREN -> columnRefOrder+ + ; + +columnRefOrderNotInParenthese + : + columnRefOrder (COMMA columnRefOrder)* -> columnRefOrder+ + ; + +// order by a,b +orderByClause +@init { gParent.pushMsg("order by clause", state); } +@after { gParent.popMsg(state); } + : + KW_ORDER KW_BY columnRefOrder ( COMMA columnRefOrder)* -> ^(TOK_ORDERBY columnRefOrder+) + ; + +clusterByClause +@init { gParent.pushMsg("cluster by clause", state); } +@after { gParent.popMsg(state); } + : + KW_CLUSTER KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_CLUSTERBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_CLUSTERBY expressionsNotInParenthese) + ) + ; + +partitionByClause +@init { gParent.pushMsg("partition by clause", state); } +@after { gParent.popMsg(state); } + : + KW_PARTITION KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) + ) + ; + +distributeByClause +@init { gParent.pushMsg("distribute by clause", state); } +@after { gParent.popMsg(state); } + : + KW_DISTRIBUTE KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) + ) + ; + +sortByClause +@init { gParent.pushMsg("sort by clause", state); } +@after { gParent.popMsg(state); } + : + KW_SORT KW_BY + ( + (LPAREN) => columnRefOrderInParenthese -> ^(TOK_SORTBY columnRefOrderInParenthese) + | + columnRefOrderNotInParenthese -> ^(TOK_SORTBY columnRefOrderNotInParenthese) + ) + ; + +// fun(par1, par2, par3) +function +@init { gParent.pushMsg("function specification", state); } +@after { gParent.popMsg(state); } + : + functionName + LPAREN + ( + (STAR) => (star=STAR) + | (dist=KW_DISTINCT)? (selectExpression (COMMA selectExpression)*)? + ) + RPAREN (KW_OVER ws=window_specification)? + -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?) + -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?) + -> ^(TOK_FUNCTIONDI functionName (selectExpression+)?) + ; + +functionName +@init { gParent.pushMsg("function name", state); } +@after { gParent.popMsg(state); } + : // Keyword IF is also a function name + (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) => (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) + | + (functionIdentifier) => functionIdentifier + | + {!useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsCastFunctionName -> Identifier[$sql11ReservedKeywordsUsedAsCastFunctionName.text] + ; + +castExpression +@init { gParent.pushMsg("cast expression", state); } +@after { gParent.popMsg(state); } + : + KW_CAST + LPAREN + expression + KW_AS + primitiveType + RPAREN -> ^(TOK_FUNCTION primitiveType expression) + ; + +caseExpression +@init { gParent.pushMsg("case expression", state); } +@after { gParent.popMsg(state); } + : + KW_CASE expression + (KW_WHEN expression KW_THEN expression)+ + (KW_ELSE expression)? + KW_END -> ^(TOK_FUNCTION KW_CASE expression*) + ; + +whenExpression +@init { gParent.pushMsg("case expression", state); } +@after { gParent.popMsg(state); } + : + KW_CASE + ( KW_WHEN expression KW_THEN expression)+ + (KW_ELSE expression)? + KW_END -> ^(TOK_FUNCTION KW_WHEN expression*) + ; + +constant +@init { gParent.pushMsg("constant", state); } +@after { gParent.popMsg(state); } + : + Number + | dateLiteral + | timestampLiteral + | intervalLiteral + | StringLiteral + | stringLiteralSequence + | BigintLiteral + | SmallintLiteral + | TinyintLiteral + | DecimalLiteral + | charSetStringLiteral + | booleanValue + ; + +stringLiteralSequence + : + StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+) + ; + +charSetStringLiteral +@init { gParent.pushMsg("character string literal", state); } +@after { gParent.popMsg(state); } + : + csName=CharSetName csLiteral=CharSetLiteral -> ^(TOK_CHARSETLITERAL $csName $csLiteral) + ; + +dateLiteral + : + KW_DATE StringLiteral -> + { + // Create DateLiteral token, but with the text of the string value + // This makes the dateLiteral more consistent with the other type literals. + adaptor.create(TOK_DATELITERAL, $StringLiteral.text) + } + | + KW_CURRENT_DATE -> ^(TOK_FUNCTION KW_CURRENT_DATE) + ; + +timestampLiteral + : + KW_TIMESTAMP StringLiteral -> + { + adaptor.create(TOK_TIMESTAMPLITERAL, $StringLiteral.text) + } + | + KW_CURRENT_TIMESTAMP -> ^(TOK_FUNCTION KW_CURRENT_TIMESTAMP) + ; + +intervalLiteral + : + KW_INTERVAL StringLiteral qualifiers=intervalQualifiers -> + { + adaptor.create($qualifiers.tree.token.getType(), $StringLiteral.text) + } + ; + +intervalQualifiers + : + KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH_LITERAL + | KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME_LITERAL + | KW_YEAR -> TOK_INTERVAL_YEAR_LITERAL + | KW_MONTH -> TOK_INTERVAL_MONTH_LITERAL + | KW_DAY -> TOK_INTERVAL_DAY_LITERAL + | KW_HOUR -> TOK_INTERVAL_HOUR_LITERAL + | KW_MINUTE -> TOK_INTERVAL_MINUTE_LITERAL + | KW_SECOND -> TOK_INTERVAL_SECOND_LITERAL + ; + +expression +@init { gParent.pushMsg("expression specification", state); } +@after { gParent.popMsg(state); } + : + precedenceOrExpression + ; + +atomExpression + : + (KW_NULL) => KW_NULL -> TOK_NULL + | (constant) => constant + | castExpression + | caseExpression + | whenExpression + | (functionName LPAREN) => function + | tableOrColumn + | LPAREN! expression RPAREN! + ; + + +precedenceFieldExpression + : + atomExpression ((LSQUARE^ expression RSQUARE!) | (DOT^ identifier))* + ; + +precedenceUnaryOperator + : + PLUS | MINUS | TILDE + ; + +nullCondition + : + KW_NULL -> ^(TOK_ISNULL) + | KW_NOT KW_NULL -> ^(TOK_ISNOTNULL) + ; + +precedenceUnaryPrefixExpression + : + (precedenceUnaryOperator^)* precedenceFieldExpression + ; + +precedenceUnarySuffixExpression + : precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? + -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression) + -> precedenceUnaryPrefixExpression + ; + + +precedenceBitwiseXorOperator + : + BITWISEXOR + ; + +precedenceBitwiseXorExpression + : + precedenceUnarySuffixExpression (precedenceBitwiseXorOperator^ precedenceUnarySuffixExpression)* + ; + + +precedenceStarOperator + : + STAR | DIVIDE | MOD | DIV + ; + +precedenceStarExpression + : + precedenceBitwiseXorExpression (precedenceStarOperator^ precedenceBitwiseXorExpression)* + ; + + +precedencePlusOperator + : + PLUS | MINUS + ; + +precedencePlusExpression + : + precedenceStarExpression (precedencePlusOperator^ precedenceStarExpression)* + ; + + +precedenceAmpersandOperator + : + AMPERSAND + ; + +precedenceAmpersandExpression + : + precedencePlusExpression (precedenceAmpersandOperator^ precedencePlusExpression)* + ; + + +precedenceBitwiseOrOperator + : + BITWISEOR + ; + +precedenceBitwiseOrExpression + : + precedenceAmpersandExpression (precedenceBitwiseOrOperator^ precedenceAmpersandExpression)* + ; + + +// Equal operators supporting NOT prefix +precedenceEqualNegatableOperator + : + KW_LIKE | KW_RLIKE | KW_REGEXP + ; + +precedenceEqualOperator + : + precedenceEqualNegatableOperator | EQUAL | EQUAL_NS | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +subQueryExpression + : + LPAREN! selectStatement[true] RPAREN! + ; + +precedenceEqualExpression + : + (LPAREN precedenceBitwiseOrExpression COMMA) => precedenceEqualExpressionMutiple + | + precedenceEqualExpressionSingle + ; + +precedenceEqualExpressionSingle + : + (left=precedenceBitwiseOrExpression -> $left) + ( + (KW_NOT precedenceEqualNegatableOperator notExpr=precedenceBitwiseOrExpression) + -> ^(KW_NOT ^(precedenceEqualNegatableOperator $precedenceEqualExpressionSingle $notExpr)) + | (precedenceEqualOperator equalExpr=precedenceBitwiseOrExpression) + -> ^(precedenceEqualOperator $precedenceEqualExpressionSingle $equalExpr) + | (KW_NOT KW_IN LPAREN KW_SELECT)=> (KW_NOT KW_IN subQueryExpression) + -> ^(KW_NOT ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle)) + | (KW_NOT KW_IN expressions) + -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions)) + | (KW_IN LPAREN KW_SELECT)=> (KW_IN subQueryExpression) + -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle) + | (KW_IN expressions) + -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions) + | ( KW_NOT KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) + -> ^(TOK_FUNCTION Identifier["between"] KW_TRUE $left $min $max) + | ( KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) + -> ^(TOK_FUNCTION Identifier["between"] KW_FALSE $left $min $max) + )* + | (KW_EXISTS LPAREN KW_SELECT)=> (KW_EXISTS subQueryExpression) -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_EXISTS) subQueryExpression) + ; + +expressions + : + LPAREN expression (COMMA expression)* RPAREN -> expression+ + ; + +//we transform the (col0, col1) in ((v00,v01),(v10,v11)) into struct(col0, col1) in (struct(v00,v01),struct(v10,v11)) +precedenceEqualExpressionMutiple + : + (LPAREN precedenceBitwiseOrExpression (COMMA precedenceBitwiseOrExpression)+ RPAREN -> ^(TOK_FUNCTION Identifier["struct"] precedenceBitwiseOrExpression+)) + ( (KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) + -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+) + | (KW_NOT KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) + -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+))) + ; + +expressionsToStruct + : + LPAREN expression (COMMA expression)* RPAREN -> ^(TOK_FUNCTION Identifier["struct"] expression+) + ; + +precedenceNotOperator + : + KW_NOT + ; + +precedenceNotExpression + : + (precedenceNotOperator^)* precedenceEqualExpression + ; + + +precedenceAndOperator + : + KW_AND + ; + +precedenceAndExpression + : + precedenceNotExpression (precedenceAndOperator^ precedenceNotExpression)* + ; + + +precedenceOrOperator + : + KW_OR + ; + +precedenceOrExpression + : + precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)* + ; + + +booleanValue + : + KW_TRUE^ | KW_FALSE^ + ; + +booleanValueTok + : + KW_TRUE -> TOK_TRUE + | KW_FALSE -> TOK_FALSE + ; + +tableOrPartition + : + tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) + ; + +partitionSpec + : + KW_PARTITION + LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) + ; + +partitionVal + : + identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) + ; + +dropPartitionSpec + : + KW_PARTITION + LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) + ; + +dropPartitionVal + : + identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) + ; + +dropPartitionOperator + : + EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +sysFuncNames + : + KW_AND + | KW_OR + | KW_NOT + | KW_LIKE + | KW_IF + | KW_CASE + | KW_WHEN + | KW_TINYINT + | KW_SMALLINT + | KW_INT + | KW_BIGINT + | KW_FLOAT + | KW_DOUBLE + | KW_BOOLEAN + | KW_STRING + | KW_BINARY + | KW_ARRAY + | KW_MAP + | KW_STRUCT + | KW_UNIONTYPE + | EQUAL + | EQUAL_NS + | NOTEQUAL + | LESSTHANOREQUALTO + | LESSTHAN + | GREATERTHANOREQUALTO + | GREATERTHAN + | DIVIDE + | PLUS + | MINUS + | STAR + | MOD + | DIV + | AMPERSAND + | TILDE + | BITWISEOR + | BITWISEXOR + | KW_RLIKE + | KW_REGEXP + | KW_IN + | KW_BETWEEN + ; + +descFuncNames + : + (sysFuncNames) => sysFuncNames + | StringLiteral + | functionIdentifier + ; + +identifier + : + Identifier + | nonReserved -> Identifier[$nonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + +functionIdentifier +@init { gParent.pushMsg("function identifier", state); } +@after { gParent.popMsg(state); } + : db=identifier DOT fn=identifier + -> Identifier[$db.text + "." + $fn.text] + | + identifier + ; + +principalIdentifier +@init { gParent.pushMsg("identifier for principal spec", state); } +@after { gParent.popMsg(state); } + : identifier + | QuotedIdentifier + ; + +//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved +//Non reserved keywords are basically the keywords that can be used as identifiers. +//All the KW_* are automatically not only keywords, but also reserved keywords. +//That means, they can NOT be used as identifiers. +//If you would like to use them as identifiers, put them in the nonReserved list below. +//If you are not sure, please refer to the SQL2011 column in +//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html +nonReserved + : + KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS + | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS + | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY + | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY + | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE + | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT + | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE + | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR + | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG + | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE + | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY + | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER + | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE + | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED + | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED + | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED + | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET + | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR + | KW_WORK + | KW_TRANSACTION + | KW_WRITE + | KW_ISOLATION + | KW_LEVEL + | KW_SNAPSHOT + | KW_AUTOCOMMIT + | KW_ANTI +; + +//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. +sql11ReservedKeywordsUsedAsCastFunctionName + : + KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP + ; + +//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. +//We are planning to remove the following whole list after several releases. +//Thus, please do not change the following list unless you know what to do. +sql11ReservedKeywordsUsedAsIdentifier + : + KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN + | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE + | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT + | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL + | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION + | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT + | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE + | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH +//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. + | KW_REGEXP | KW_RLIKE + ; diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g new file mode 100644 index 000000000000..48bc8b0a300a --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g @@ -0,0 +1,226 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar SelectClauseParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.errors.add(new ParseError(gParent, e, tokenNames)); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------- Rules for parsing selectClause ----------------------------- +// select a,b,c ... +selectClause +@init { gParent.pushMsg("select clause", state); } +@after { gParent.popMsg(state); } + : + KW_SELECT hintClause? (((KW_ALL | dist=KW_DISTINCT)? selectList) + | (transform=KW_TRANSFORM selectTrfmClause)) + -> {$transform == null && $dist == null}? ^(TOK_SELECT hintClause? selectList) + -> {$transform == null && $dist != null}? ^(TOK_SELECTDI hintClause? selectList) + -> ^(TOK_SELECT hintClause? ^(TOK_SELEXPR selectTrfmClause) ) + | + trfmClause ->^(TOK_SELECT ^(TOK_SELEXPR trfmClause)) + ; + +selectList +@init { gParent.pushMsg("select list", state); } +@after { gParent.popMsg(state); } + : + selectItem ( COMMA selectItem )* -> selectItem+ + ; + +selectTrfmClause +@init { gParent.pushMsg("transform clause", state); } +@after { gParent.popMsg(state); } + : + LPAREN selectExpressionList RPAREN + inSerde=rowFormat inRec=recordWriter + KW_USING StringLiteral + ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? + outSerde=rowFormat outRec=recordReader + -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) + ; + +hintClause +@init { gParent.pushMsg("hint clause", state); } +@after { gParent.popMsg(state); } + : + DIVIDE STAR PLUS hintList STAR DIVIDE -> ^(TOK_HINTLIST hintList) + ; + +hintList +@init { gParent.pushMsg("hint list", state); } +@after { gParent.popMsg(state); } + : + hintItem (COMMA hintItem)* -> hintItem+ + ; + +hintItem +@init { gParent.pushMsg("hint item", state); } +@after { gParent.popMsg(state); } + : + hintName (LPAREN hintArgs RPAREN)? -> ^(TOK_HINT hintName hintArgs?) + ; + +hintName +@init { gParent.pushMsg("hint name", state); } +@after { gParent.popMsg(state); } + : + KW_MAPJOIN -> TOK_MAPJOIN + | KW_STREAMTABLE -> TOK_STREAMTABLE + ; + +hintArgs +@init { gParent.pushMsg("hint arguments", state); } +@after { gParent.popMsg(state); } + : + hintArgName (COMMA hintArgName)* -> ^(TOK_HINTARGLIST hintArgName+) + ; + +hintArgName +@init { gParent.pushMsg("hint argument name", state); } +@after { gParent.popMsg(state); } + : + identifier + ; + +selectItem +@init { gParent.pushMsg("selection target", state); } +@after { gParent.popMsg(state); } + : + (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns) + | + ( expression + ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))? + ) -> ^(TOK_SELEXPR expression identifier*) + ; + +trfmClause +@init { gParent.pushMsg("transform clause", state); } +@after { gParent.popMsg(state); } + : + ( KW_MAP selectExpressionList + | KW_REDUCE selectExpressionList ) + inSerde=rowFormat inRec=recordWriter + KW_USING StringLiteral + ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? + outSerde=rowFormat outRec=recordReader + -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) + ; + +selectExpression +@init { gParent.pushMsg("select expression", state); } +@after { gParent.popMsg(state); } + : + (tableAllColumns) => tableAllColumns + | + expression + ; + +selectExpressionList +@init { gParent.pushMsg("select expression list", state); } +@after { gParent.popMsg(state); } + : + selectExpression (COMMA selectExpression)* -> ^(TOK_EXPLIST selectExpression+) + ; + +//---------------------- Rules for windowing clauses ------------------------------- +window_clause +@init { gParent.pushMsg("window_clause", state); } +@after { gParent.popMsg(state); } +: + KW_WINDOW window_defn (COMMA window_defn)* -> ^(KW_WINDOW window_defn+) +; + +window_defn +@init { gParent.pushMsg("window_defn", state); } +@after { gParent.popMsg(state); } +: + Identifier KW_AS window_specification -> ^(TOK_WINDOWDEF Identifier window_specification) +; + +window_specification +@init { gParent.pushMsg("window_specification", state); } +@after { gParent.popMsg(state); } +: + (Identifier | ( LPAREN Identifier? partitioningSpec? window_frame? RPAREN)) -> ^(TOK_WINDOWSPEC Identifier? partitioningSpec? window_frame?) +; + +window_frame : + window_range_expression | + window_value_expression +; + +window_range_expression +@init { gParent.pushMsg("window_range_expression", state); } +@after { gParent.popMsg(state); } +: + KW_ROWS sb=window_frame_start_boundary -> ^(TOK_WINDOWRANGE $sb) | + KW_ROWS KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWRANGE $s $end) +; + +window_value_expression +@init { gParent.pushMsg("window_value_expression", state); } +@after { gParent.popMsg(state); } +: + KW_RANGE sb=window_frame_start_boundary -> ^(TOK_WINDOWVALUES $sb) | + KW_RANGE KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWVALUES $s $end) +; + +window_frame_start_boundary +@init { gParent.pushMsg("windowframestartboundary", state); } +@after { gParent.popMsg(state); } +: + KW_UNBOUNDED KW_PRECEDING -> ^(KW_PRECEDING KW_UNBOUNDED) | + KW_CURRENT KW_ROW -> ^(KW_CURRENT) | + Number KW_PRECEDING -> ^(KW_PRECEDING Number) +; + +window_frame_boundary +@init { gParent.pushMsg("windowframeboundary", state); } +@after { gParent.popMsg(state); } +: + KW_UNBOUNDED (r=KW_PRECEDING|r=KW_FOLLOWING) -> ^($r KW_UNBOUNDED) | + KW_CURRENT KW_ROW -> ^(KW_CURRENT) | + Number (d=KW_PRECEDING | d=KW_FOLLOWING ) -> ^($d Number) +; + diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g new file mode 100644 index 000000000000..ee1b8989b5af --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g @@ -0,0 +1,474 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +lexer grammar SparkSqlLexer; + +@lexer::header { +package org.apache.spark.sql.parser; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; +} + +@lexer::members { + private Configuration hiveConf; + + public void setHiveConf(Configuration hiveConf) { + this.hiveConf = hiveConf; + } + + protected boolean allowQuotedId() { + String supportedQIds = HiveConf.getVar(hiveConf, HiveConf.ConfVars.HIVE_QUOTEDID_SUPPORT); + return !"none".equals(supportedQIds); + } +} + +// Keywords + +KW_TRUE : 'TRUE'; +KW_FALSE : 'FALSE'; +KW_ALL : 'ALL'; +KW_NONE: 'NONE'; +KW_AND : 'AND'; +KW_OR : 'OR'; +KW_NOT : 'NOT' | '!'; +KW_LIKE : 'LIKE'; + +KW_IF : 'IF'; +KW_EXISTS : 'EXISTS'; + +KW_ASC : 'ASC'; +KW_DESC : 'DESC'; +KW_ORDER : 'ORDER'; +KW_GROUP : 'GROUP'; +KW_BY : 'BY'; +KW_HAVING : 'HAVING'; +KW_WHERE : 'WHERE'; +KW_FROM : 'FROM'; +KW_AS : 'AS'; +KW_SELECT : 'SELECT'; +KW_DISTINCT : 'DISTINCT'; +KW_INSERT : 'INSERT'; +KW_OVERWRITE : 'OVERWRITE'; +KW_OUTER : 'OUTER'; +KW_UNIQUEJOIN : 'UNIQUEJOIN'; +KW_PRESERVE : 'PRESERVE'; +KW_JOIN : 'JOIN'; +KW_LEFT : 'LEFT'; +KW_RIGHT : 'RIGHT'; +KW_FULL : 'FULL'; +KW_ANTI : 'ANTI'; +KW_ON : 'ON'; +KW_PARTITION : 'PARTITION'; +KW_PARTITIONS : 'PARTITIONS'; +KW_TABLE: 'TABLE'; +KW_TABLES: 'TABLES'; +KW_COLUMNS: 'COLUMNS'; +KW_INDEX: 'INDEX'; +KW_INDEXES: 'INDEXES'; +KW_REBUILD: 'REBUILD'; +KW_FUNCTIONS: 'FUNCTIONS'; +KW_SHOW: 'SHOW'; +KW_MSCK: 'MSCK'; +KW_REPAIR: 'REPAIR'; +KW_DIRECTORY: 'DIRECTORY'; +KW_LOCAL: 'LOCAL'; +KW_TRANSFORM : 'TRANSFORM'; +KW_USING: 'USING'; +KW_CLUSTER: 'CLUSTER'; +KW_DISTRIBUTE: 'DISTRIBUTE'; +KW_SORT: 'SORT'; +KW_UNION: 'UNION'; +KW_LOAD: 'LOAD'; +KW_EXPORT: 'EXPORT'; +KW_IMPORT: 'IMPORT'; +KW_REPLICATION: 'REPLICATION'; +KW_METADATA: 'METADATA'; +KW_DATA: 'DATA'; +KW_INPATH: 'INPATH'; +KW_IS: 'IS'; +KW_NULL: 'NULL'; +KW_CREATE: 'CREATE'; +KW_EXTERNAL: 'EXTERNAL'; +KW_ALTER: 'ALTER'; +KW_CHANGE: 'CHANGE'; +KW_COLUMN: 'COLUMN'; +KW_FIRST: 'FIRST'; +KW_AFTER: 'AFTER'; +KW_DESCRIBE: 'DESCRIBE'; +KW_DROP: 'DROP'; +KW_RENAME: 'RENAME'; +KW_TO: 'TO'; +KW_COMMENT: 'COMMENT'; +KW_BOOLEAN: 'BOOLEAN'; +KW_TINYINT: 'TINYINT'; +KW_SMALLINT: 'SMALLINT'; +KW_INT: 'INT'; +KW_BIGINT: 'BIGINT'; +KW_FLOAT: 'FLOAT'; +KW_DOUBLE: 'DOUBLE'; +KW_DATE: 'DATE'; +KW_DATETIME: 'DATETIME'; +KW_TIMESTAMP: 'TIMESTAMP'; +KW_INTERVAL: 'INTERVAL'; +KW_DECIMAL: 'DECIMAL'; +KW_STRING: 'STRING'; +KW_CHAR: 'CHAR'; +KW_VARCHAR: 'VARCHAR'; +KW_ARRAY: 'ARRAY'; +KW_STRUCT: 'STRUCT'; +KW_MAP: 'MAP'; +KW_UNIONTYPE: 'UNIONTYPE'; +KW_REDUCE: 'REDUCE'; +KW_PARTITIONED: 'PARTITIONED'; +KW_CLUSTERED: 'CLUSTERED'; +KW_SORTED: 'SORTED'; +KW_INTO: 'INTO'; +KW_BUCKETS: 'BUCKETS'; +KW_ROW: 'ROW'; +KW_ROWS: 'ROWS'; +KW_FORMAT: 'FORMAT'; +KW_DELIMITED: 'DELIMITED'; +KW_FIELDS: 'FIELDS'; +KW_TERMINATED: 'TERMINATED'; +KW_ESCAPED: 'ESCAPED'; +KW_COLLECTION: 'COLLECTION'; +KW_ITEMS: 'ITEMS'; +KW_KEYS: 'KEYS'; +KW_KEY_TYPE: '$KEY$'; +KW_LINES: 'LINES'; +KW_STORED: 'STORED'; +KW_FILEFORMAT: 'FILEFORMAT'; +KW_INPUTFORMAT: 'INPUTFORMAT'; +KW_OUTPUTFORMAT: 'OUTPUTFORMAT'; +KW_INPUTDRIVER: 'INPUTDRIVER'; +KW_OUTPUTDRIVER: 'OUTPUTDRIVER'; +KW_ENABLE: 'ENABLE'; +KW_DISABLE: 'DISABLE'; +KW_LOCATION: 'LOCATION'; +KW_TABLESAMPLE: 'TABLESAMPLE'; +KW_BUCKET: 'BUCKET'; +KW_OUT: 'OUT'; +KW_OF: 'OF'; +KW_PERCENT: 'PERCENT'; +KW_CAST: 'CAST'; +KW_ADD: 'ADD'; +KW_REPLACE: 'REPLACE'; +KW_RLIKE: 'RLIKE'; +KW_REGEXP: 'REGEXP'; +KW_TEMPORARY: 'TEMPORARY'; +KW_FUNCTION: 'FUNCTION'; +KW_MACRO: 'MACRO'; +KW_FILE: 'FILE'; +KW_JAR: 'JAR'; +KW_EXPLAIN: 'EXPLAIN'; +KW_EXTENDED: 'EXTENDED'; +KW_FORMATTED: 'FORMATTED'; +KW_PRETTY: 'PRETTY'; +KW_DEPENDENCY: 'DEPENDENCY'; +KW_LOGICAL: 'LOGICAL'; +KW_SERDE: 'SERDE'; +KW_WITH: 'WITH'; +KW_DEFERRED: 'DEFERRED'; +KW_SERDEPROPERTIES: 'SERDEPROPERTIES'; +KW_DBPROPERTIES: 'DBPROPERTIES'; +KW_LIMIT: 'LIMIT'; +KW_SET: 'SET'; +KW_UNSET: 'UNSET'; +KW_TBLPROPERTIES: 'TBLPROPERTIES'; +KW_IDXPROPERTIES: 'IDXPROPERTIES'; +KW_VALUE_TYPE: '$VALUE$'; +KW_ELEM_TYPE: '$ELEM$'; +KW_DEFINED: 'DEFINED'; +KW_CASE: 'CASE'; +KW_WHEN: 'WHEN'; +KW_THEN: 'THEN'; +KW_ELSE: 'ELSE'; +KW_END: 'END'; +KW_MAPJOIN: 'MAPJOIN'; +KW_STREAMTABLE: 'STREAMTABLE'; +KW_CLUSTERSTATUS: 'CLUSTERSTATUS'; +KW_UTC: 'UTC'; +KW_UTCTIMESTAMP: 'UTC_TMESTAMP'; +KW_LONG: 'LONG'; +KW_DELETE: 'DELETE'; +KW_PLUS: 'PLUS'; +KW_MINUS: 'MINUS'; +KW_FETCH: 'FETCH'; +KW_INTERSECT: 'INTERSECT'; +KW_VIEW: 'VIEW'; +KW_IN: 'IN'; +KW_DATABASE: 'DATABASE'; +KW_DATABASES: 'DATABASES'; +KW_MATERIALIZED: 'MATERIALIZED'; +KW_SCHEMA: 'SCHEMA'; +KW_SCHEMAS: 'SCHEMAS'; +KW_GRANT: 'GRANT'; +KW_REVOKE: 'REVOKE'; +KW_SSL: 'SSL'; +KW_UNDO: 'UNDO'; +KW_LOCK: 'LOCK'; +KW_LOCKS: 'LOCKS'; +KW_UNLOCK: 'UNLOCK'; +KW_SHARED: 'SHARED'; +KW_EXCLUSIVE: 'EXCLUSIVE'; +KW_PROCEDURE: 'PROCEDURE'; +KW_UNSIGNED: 'UNSIGNED'; +KW_WHILE: 'WHILE'; +KW_READ: 'READ'; +KW_READS: 'READS'; +KW_PURGE: 'PURGE'; +KW_RANGE: 'RANGE'; +KW_ANALYZE: 'ANALYZE'; +KW_BEFORE: 'BEFORE'; +KW_BETWEEN: 'BETWEEN'; +KW_BOTH: 'BOTH'; +KW_BINARY: 'BINARY'; +KW_CROSS: 'CROSS'; +KW_CONTINUE: 'CONTINUE'; +KW_CURSOR: 'CURSOR'; +KW_TRIGGER: 'TRIGGER'; +KW_RECORDREADER: 'RECORDREADER'; +KW_RECORDWRITER: 'RECORDWRITER'; +KW_SEMI: 'SEMI'; +KW_LATERAL: 'LATERAL'; +KW_TOUCH: 'TOUCH'; +KW_ARCHIVE: 'ARCHIVE'; +KW_UNARCHIVE: 'UNARCHIVE'; +KW_COMPUTE: 'COMPUTE'; +KW_STATISTICS: 'STATISTICS'; +KW_USE: 'USE'; +KW_OPTION: 'OPTION'; +KW_CONCATENATE: 'CONCATENATE'; +KW_SHOW_DATABASE: 'SHOW_DATABASE'; +KW_UPDATE: 'UPDATE'; +KW_RESTRICT: 'RESTRICT'; +KW_CASCADE: 'CASCADE'; +KW_SKEWED: 'SKEWED'; +KW_ROLLUP: 'ROLLUP'; +KW_CUBE: 'CUBE'; +KW_DIRECTORIES: 'DIRECTORIES'; +KW_FOR: 'FOR'; +KW_WINDOW: 'WINDOW'; +KW_UNBOUNDED: 'UNBOUNDED'; +KW_PRECEDING: 'PRECEDING'; +KW_FOLLOWING: 'FOLLOWING'; +KW_CURRENT: 'CURRENT'; +KW_CURRENT_DATE: 'CURRENT_DATE'; +KW_CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +KW_LESS: 'LESS'; +KW_MORE: 'MORE'; +KW_OVER: 'OVER'; +KW_GROUPING: 'GROUPING'; +KW_SETS: 'SETS'; +KW_TRUNCATE: 'TRUNCATE'; +KW_NOSCAN: 'NOSCAN'; +KW_PARTIALSCAN: 'PARTIALSCAN'; +KW_USER: 'USER'; +KW_ROLE: 'ROLE'; +KW_ROLES: 'ROLES'; +KW_INNER: 'INNER'; +KW_EXCHANGE: 'EXCHANGE'; +KW_URI: 'URI'; +KW_SERVER : 'SERVER'; +KW_ADMIN: 'ADMIN'; +KW_OWNER: 'OWNER'; +KW_PRINCIPALS: 'PRINCIPALS'; +KW_COMPACT: 'COMPACT'; +KW_COMPACTIONS: 'COMPACTIONS'; +KW_TRANSACTIONS: 'TRANSACTIONS'; +KW_REWRITE : 'REWRITE'; +KW_AUTHORIZATION: 'AUTHORIZATION'; +KW_CONF: 'CONF'; +KW_VALUES: 'VALUES'; +KW_RELOAD: 'RELOAD'; +KW_YEAR: 'YEAR'; +KW_MONTH: 'MONTH'; +KW_DAY: 'DAY'; +KW_HOUR: 'HOUR'; +KW_MINUTE: 'MINUTE'; +KW_SECOND: 'SECOND'; +KW_START: 'START'; +KW_TRANSACTION: 'TRANSACTION'; +KW_COMMIT: 'COMMIT'; +KW_ROLLBACK: 'ROLLBACK'; +KW_WORK: 'WORK'; +KW_ONLY: 'ONLY'; +KW_WRITE: 'WRITE'; +KW_ISOLATION: 'ISOLATION'; +KW_LEVEL: 'LEVEL'; +KW_SNAPSHOT: 'SNAPSHOT'; +KW_AUTOCOMMIT: 'AUTOCOMMIT'; + +// Operators +// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. + +DOT : '.'; // generated as a part of Number rule +COLON : ':' ; +COMMA : ',' ; +SEMICOLON : ';' ; + +LPAREN : '(' ; +RPAREN : ')' ; +LSQUARE : '[' ; +RSQUARE : ']' ; +LCURLY : '{'; +RCURLY : '}'; + +EQUAL : '=' | '=='; +EQUAL_NS : '<=>'; +NOTEQUAL : '<>' | '!='; +LESSTHANOREQUALTO : '<='; +LESSTHAN : '<'; +GREATERTHANOREQUALTO : '>='; +GREATERTHAN : '>'; + +DIVIDE : '/'; +PLUS : '+'; +MINUS : '-'; +STAR : '*'; +MOD : '%'; +DIV : 'DIV'; + +AMPERSAND : '&'; +TILDE : '~'; +BITWISEOR : '|'; +BITWISEXOR : '^'; +QUESTION : '?'; +DOLLAR : '$'; + +// LITERALS +fragment +Letter + : 'a'..'z' | 'A'..'Z' + ; + +fragment +HexDigit + : 'a'..'f' | 'A'..'F' + ; + +fragment +Digit + : + '0'..'9' + ; + +fragment +Exponent + : + ('e' | 'E') ( PLUS|MINUS )? (Digit)+ + ; + +fragment +RegexComponent + : 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' + | PLUS | STAR | QUESTION | MINUS | DOT + | LPAREN | RPAREN | LSQUARE | RSQUARE | LCURLY | RCURLY + | BITWISEXOR | BITWISEOR | DOLLAR | '!' + ; + +StringLiteral + : + ( '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' + )+ + ; + +CharSetLiteral + : + StringLiteral + | '0' 'X' (HexDigit|Digit)+ + ; + +BigintLiteral + : + (Digit)+ 'L' + ; + +SmallintLiteral + : + (Digit)+ 'S' + ; + +TinyintLiteral + : + (Digit)+ 'Y' + ; + +DecimalLiteral + : + Number 'B' 'D' + ; + +ByteLengthLiteral + : + (Digit)+ ('b' | 'B' | 'k' | 'K' | 'm' | 'M' | 'g' | 'G') + ; + +Number + : + (Digit)+ ( DOT (Digit)* (Exponent)? | Exponent)? + ; + +/* +An Identifier can be: +- tableName +- columnName +- select expr alias +- lateral view aliases +- database name +- view name +- subquery alias +- function name +- ptf argument identifier +- index name +- property name for: db,tbl,partition... +- fileFormat +- role name +- privilege name +- principal name +- macro name +- hint name +- window name +*/ +Identifier + : + (Letter | Digit) (Letter | Digit | '_')* + | {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers; + at the API level only columns are allowed to be of this form */ + | '`' RegexComponent+ '`' + ; + +fragment +QuotedIdentifier + : + '`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); } + ; + +CharSetName + : + '_' (Letter | Digit | '_' | '-' | '.' | ':' )+ + ; + +WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} + ; + +COMMENT + : '--' (~('\n'|'\r'))* + { $channel=HIDDEN; } + ; + diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g new file mode 100644 index 000000000000..69574d713d0b --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g @@ -0,0 +1,2457 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar SparkSqlParser; + +options +{ +tokenVocab=SparkSqlLexer; +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} +import SelectClauseParser, FromClauseParser, IdentifiersParser; + +tokens { +TOK_INSERT; +TOK_QUERY; +TOK_SELECT; +TOK_SELECTDI; +TOK_SELEXPR; +TOK_FROM; +TOK_TAB; +TOK_PARTSPEC; +TOK_PARTVAL; +TOK_DIR; +TOK_TABREF; +TOK_SUBQUERY; +TOK_INSERT_INTO; +TOK_DESTINATION; +TOK_ALLCOLREF; +TOK_TABLE_OR_COL; +TOK_FUNCTION; +TOK_FUNCTIONDI; +TOK_FUNCTIONSTAR; +TOK_WHERE; +TOK_OP_EQ; +TOK_OP_NE; +TOK_OP_LE; +TOK_OP_LT; +TOK_OP_GE; +TOK_OP_GT; +TOK_OP_DIV; +TOK_OP_ADD; +TOK_OP_SUB; +TOK_OP_MUL; +TOK_OP_MOD; +TOK_OP_BITAND; +TOK_OP_BITNOT; +TOK_OP_BITOR; +TOK_OP_BITXOR; +TOK_OP_AND; +TOK_OP_OR; +TOK_OP_NOT; +TOK_OP_LIKE; +TOK_TRUE; +TOK_FALSE; +TOK_TRANSFORM; +TOK_SERDE; +TOK_SERDENAME; +TOK_SERDEPROPS; +TOK_EXPLIST; +TOK_ALIASLIST; +TOK_GROUPBY; +TOK_ROLLUP_GROUPBY; +TOK_CUBE_GROUPBY; +TOK_GROUPING_SETS; +TOK_GROUPING_SETS_EXPRESSION; +TOK_HAVING; +TOK_ORDERBY; +TOK_CLUSTERBY; +TOK_DISTRIBUTEBY; +TOK_SORTBY; +TOK_UNIONALL; +TOK_UNIONDISTINCT; +TOK_JOIN; +TOK_LEFTOUTERJOIN; +TOK_RIGHTOUTERJOIN; +TOK_FULLOUTERJOIN; +TOK_UNIQUEJOIN; +TOK_CROSSJOIN; +TOK_LOAD; +TOK_EXPORT; +TOK_IMPORT; +TOK_REPLICATION; +TOK_METADATA; +TOK_NULL; +TOK_ISNULL; +TOK_ISNOTNULL; +TOK_TINYINT; +TOK_SMALLINT; +TOK_INT; +TOK_BIGINT; +TOK_BOOLEAN; +TOK_FLOAT; +TOK_DOUBLE; +TOK_DATE; +TOK_DATELITERAL; +TOK_DATETIME; +TOK_TIMESTAMP; +TOK_TIMESTAMPLITERAL; +TOK_INTERVAL_YEAR_MONTH; +TOK_INTERVAL_YEAR_MONTH_LITERAL; +TOK_INTERVAL_DAY_TIME; +TOK_INTERVAL_DAY_TIME_LITERAL; +TOK_INTERVAL_YEAR_LITERAL; +TOK_INTERVAL_MONTH_LITERAL; +TOK_INTERVAL_DAY_LITERAL; +TOK_INTERVAL_HOUR_LITERAL; +TOK_INTERVAL_MINUTE_LITERAL; +TOK_INTERVAL_SECOND_LITERAL; +TOK_STRING; +TOK_CHAR; +TOK_VARCHAR; +TOK_BINARY; +TOK_DECIMAL; +TOK_LIST; +TOK_STRUCT; +TOK_MAP; +TOK_UNIONTYPE; +TOK_COLTYPELIST; +TOK_CREATEDATABASE; +TOK_CREATETABLE; +TOK_TRUNCATETABLE; +TOK_CREATEINDEX; +TOK_CREATEINDEX_INDEXTBLNAME; +TOK_DEFERRED_REBUILDINDEX; +TOK_DROPINDEX; +TOK_LIKETABLE; +TOK_DESCTABLE; +TOK_DESCFUNCTION; +TOK_ALTERTABLE; +TOK_ALTERTABLE_RENAME; +TOK_ALTERTABLE_ADDCOLS; +TOK_ALTERTABLE_RENAMECOL; +TOK_ALTERTABLE_RENAMEPART; +TOK_ALTERTABLE_REPLACECOLS; +TOK_ALTERTABLE_ADDPARTS; +TOK_ALTERTABLE_DROPPARTS; +TOK_ALTERTABLE_PARTCOLTYPE; +TOK_ALTERTABLE_MERGEFILES; +TOK_ALTERTABLE_TOUCH; +TOK_ALTERTABLE_ARCHIVE; +TOK_ALTERTABLE_UNARCHIVE; +TOK_ALTERTABLE_SERDEPROPERTIES; +TOK_ALTERTABLE_SERIALIZER; +TOK_ALTERTABLE_UPDATECOLSTATS; +TOK_TABLE_PARTITION; +TOK_ALTERTABLE_FILEFORMAT; +TOK_ALTERTABLE_LOCATION; +TOK_ALTERTABLE_PROPERTIES; +TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION; +TOK_ALTERTABLE_DROPPROPERTIES; +TOK_ALTERTABLE_SKEWED; +TOK_ALTERTABLE_EXCHANGEPARTITION; +TOK_ALTERTABLE_SKEWED_LOCATION; +TOK_ALTERTABLE_BUCKETS; +TOK_ALTERTABLE_CLUSTER_SORT; +TOK_ALTERTABLE_COMPACT; +TOK_ALTERINDEX_REBUILD; +TOK_ALTERINDEX_PROPERTIES; +TOK_MSCK; +TOK_SHOWDATABASES; +TOK_SHOWTABLES; +TOK_SHOWCOLUMNS; +TOK_SHOWFUNCTIONS; +TOK_SHOWPARTITIONS; +TOK_SHOW_CREATEDATABASE; +TOK_SHOW_CREATETABLE; +TOK_SHOW_TABLESTATUS; +TOK_SHOW_TBLPROPERTIES; +TOK_SHOWLOCKS; +TOK_SHOWCONF; +TOK_LOCKTABLE; +TOK_UNLOCKTABLE; +TOK_LOCKDB; +TOK_UNLOCKDB; +TOK_SWITCHDATABASE; +TOK_DROPDATABASE; +TOK_DROPTABLE; +TOK_DATABASECOMMENT; +TOK_TABCOLLIST; +TOK_TABCOL; +TOK_TABLECOMMENT; +TOK_TABLEPARTCOLS; +TOK_TABLEROWFORMAT; +TOK_TABLEROWFORMATFIELD; +TOK_TABLEROWFORMATCOLLITEMS; +TOK_TABLEROWFORMATMAPKEYS; +TOK_TABLEROWFORMATLINES; +TOK_TABLEROWFORMATNULL; +TOK_TABLEFILEFORMAT; +TOK_FILEFORMAT_GENERIC; +TOK_OFFLINE; +TOK_ENABLE; +TOK_DISABLE; +TOK_READONLY; +TOK_NO_DROP; +TOK_STORAGEHANDLER; +TOK_NOT_CLUSTERED; +TOK_NOT_SORTED; +TOK_TABCOLNAME; +TOK_TABLELOCATION; +TOK_PARTITIONLOCATION; +TOK_TABLEBUCKETSAMPLE; +TOK_TABLESPLITSAMPLE; +TOK_PERCENT; +TOK_LENGTH; +TOK_ROWCOUNT; +TOK_TMP_FILE; +TOK_TABSORTCOLNAMEASC; +TOK_TABSORTCOLNAMEDESC; +TOK_STRINGLITERALSEQUENCE; +TOK_CHARSETLITERAL; +TOK_CREATEFUNCTION; +TOK_DROPFUNCTION; +TOK_RELOADFUNCTION; +TOK_CREATEMACRO; +TOK_DROPMACRO; +TOK_TEMPORARY; +TOK_CREATEVIEW; +TOK_DROPVIEW; +TOK_ALTERVIEW; +TOK_ALTERVIEW_PROPERTIES; +TOK_ALTERVIEW_DROPPROPERTIES; +TOK_ALTERVIEW_ADDPARTS; +TOK_ALTERVIEW_DROPPARTS; +TOK_ALTERVIEW_RENAME; +TOK_VIEWPARTCOLS; +TOK_EXPLAIN; +TOK_EXPLAIN_SQ_REWRITE; +TOK_TABLESERIALIZER; +TOK_TABLEPROPERTIES; +TOK_TABLEPROPLIST; +TOK_INDEXPROPERTIES; +TOK_INDEXPROPLIST; +TOK_TABTYPE; +TOK_LIMIT; +TOK_TABLEPROPERTY; +TOK_IFEXISTS; +TOK_IFNOTEXISTS; +TOK_ORREPLACE; +TOK_HINTLIST; +TOK_HINT; +TOK_MAPJOIN; +TOK_STREAMTABLE; +TOK_HINTARGLIST; +TOK_USERSCRIPTCOLNAMES; +TOK_USERSCRIPTCOLSCHEMA; +TOK_RECORDREADER; +TOK_RECORDWRITER; +TOK_LEFTSEMIJOIN; +TOK_ANTIJOIN; +TOK_LATERAL_VIEW; +TOK_LATERAL_VIEW_OUTER; +TOK_TABALIAS; +TOK_ANALYZE; +TOK_CREATEROLE; +TOK_DROPROLE; +TOK_GRANT; +TOK_REVOKE; +TOK_SHOW_GRANT; +TOK_PRIVILEGE_LIST; +TOK_PRIVILEGE; +TOK_PRINCIPAL_NAME; +TOK_USER; +TOK_GROUP; +TOK_ROLE; +TOK_RESOURCE_ALL; +TOK_GRANT_WITH_OPTION; +TOK_GRANT_WITH_ADMIN_OPTION; +TOK_ADMIN_OPTION_FOR; +TOK_GRANT_OPTION_FOR; +TOK_PRIV_ALL; +TOK_PRIV_ALTER_METADATA; +TOK_PRIV_ALTER_DATA; +TOK_PRIV_DELETE; +TOK_PRIV_DROP; +TOK_PRIV_INDEX; +TOK_PRIV_INSERT; +TOK_PRIV_LOCK; +TOK_PRIV_SELECT; +TOK_PRIV_SHOW_DATABASE; +TOK_PRIV_CREATE; +TOK_PRIV_OBJECT; +TOK_PRIV_OBJECT_COL; +TOK_GRANT_ROLE; +TOK_REVOKE_ROLE; +TOK_SHOW_ROLE_GRANT; +TOK_SHOW_ROLES; +TOK_SHOW_SET_ROLE; +TOK_SHOW_ROLE_PRINCIPALS; +TOK_SHOWINDEXES; +TOK_SHOWDBLOCKS; +TOK_INDEXCOMMENT; +TOK_DESCDATABASE; +TOK_DATABASEPROPERTIES; +TOK_DATABASELOCATION; +TOK_DBPROPLIST; +TOK_ALTERDATABASE_PROPERTIES; +TOK_ALTERDATABASE_OWNER; +TOK_TABNAME; +TOK_TABSRC; +TOK_RESTRICT; +TOK_CASCADE; +TOK_TABLESKEWED; +TOK_TABCOLVALUE; +TOK_TABCOLVALUE_PAIR; +TOK_TABCOLVALUES; +TOK_SKEWED_LOCATIONS; +TOK_SKEWED_LOCATION_LIST; +TOK_SKEWED_LOCATION_MAP; +TOK_STOREDASDIRS; +TOK_PARTITIONINGSPEC; +TOK_PTBLFUNCTION; +TOK_WINDOWDEF; +TOK_WINDOWSPEC; +TOK_WINDOWVALUES; +TOK_WINDOWRANGE; +TOK_SUBQUERY_EXPR; +TOK_SUBQUERY_OP; +TOK_SUBQUERY_OP_NOTIN; +TOK_SUBQUERY_OP_NOTEXISTS; +TOK_DB_TYPE; +TOK_TABLE_TYPE; +TOK_CTE; +TOK_ARCHIVE; +TOK_FILE; +TOK_JAR; +TOK_RESOURCE_URI; +TOK_RESOURCE_LIST; +TOK_SHOW_COMPACTIONS; +TOK_SHOW_TRANSACTIONS; +TOK_DELETE_FROM; +TOK_UPDATE_TABLE; +TOK_SET_COLUMNS_CLAUSE; +TOK_VALUE_ROW; +TOK_VALUES_TABLE; +TOK_VIRTUAL_TABLE; +TOK_VIRTUAL_TABREF; +TOK_ANONYMOUS; +TOK_COL_NAME; +TOK_URI_TYPE; +TOK_SERVER_TYPE; +TOK_START_TRANSACTION; +TOK_ISOLATION_LEVEL; +TOK_ISOLATION_SNAPSHOT; +TOK_TXN_ACCESS_MODE; +TOK_TXN_READ_ONLY; +TOK_TXN_READ_WRITE; +TOK_COMMIT; +TOK_ROLLBACK; +TOK_SET_AUTOCOMMIT; +} + + +// Package headers +@header { +package org.apache.spark.sql.parser; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; +} + + +@members { + ArrayList errors = new ArrayList(); + Stack msgs = new Stack(); + + private static HashMap xlateMap; + static { + //this is used to support auto completion in CLI + xlateMap = new HashMap(); + + // Keywords + xlateMap.put("KW_TRUE", "TRUE"); + xlateMap.put("KW_FALSE", "FALSE"); + xlateMap.put("KW_ALL", "ALL"); + xlateMap.put("KW_NONE", "NONE"); + xlateMap.put("KW_AND", "AND"); + xlateMap.put("KW_OR", "OR"); + xlateMap.put("KW_NOT", "NOT"); + xlateMap.put("KW_LIKE", "LIKE"); + + xlateMap.put("KW_ASC", "ASC"); + xlateMap.put("KW_DESC", "DESC"); + xlateMap.put("KW_ORDER", "ORDER"); + xlateMap.put("KW_BY", "BY"); + xlateMap.put("KW_GROUP", "GROUP"); + xlateMap.put("KW_WHERE", "WHERE"); + xlateMap.put("KW_FROM", "FROM"); + xlateMap.put("KW_AS", "AS"); + xlateMap.put("KW_SELECT", "SELECT"); + xlateMap.put("KW_DISTINCT", "DISTINCT"); + xlateMap.put("KW_INSERT", "INSERT"); + xlateMap.put("KW_OVERWRITE", "OVERWRITE"); + xlateMap.put("KW_OUTER", "OUTER"); + xlateMap.put("KW_JOIN", "JOIN"); + xlateMap.put("KW_LEFT", "LEFT"); + xlateMap.put("KW_RIGHT", "RIGHT"); + xlateMap.put("KW_FULL", "FULL"); + xlateMap.put("KW_ON", "ON"); + xlateMap.put("KW_PARTITION", "PARTITION"); + xlateMap.put("KW_PARTITIONS", "PARTITIONS"); + xlateMap.put("KW_TABLE", "TABLE"); + xlateMap.put("KW_TABLES", "TABLES"); + xlateMap.put("KW_TBLPROPERTIES", "TBLPROPERTIES"); + xlateMap.put("KW_SHOW", "SHOW"); + xlateMap.put("KW_MSCK", "MSCK"); + xlateMap.put("KW_DIRECTORY", "DIRECTORY"); + xlateMap.put("KW_LOCAL", "LOCAL"); + xlateMap.put("KW_TRANSFORM", "TRANSFORM"); + xlateMap.put("KW_USING", "USING"); + xlateMap.put("KW_CLUSTER", "CLUSTER"); + xlateMap.put("KW_DISTRIBUTE", "DISTRIBUTE"); + xlateMap.put("KW_SORT", "SORT"); + xlateMap.put("KW_UNION", "UNION"); + xlateMap.put("KW_LOAD", "LOAD"); + xlateMap.put("KW_DATA", "DATA"); + xlateMap.put("KW_INPATH", "INPATH"); + xlateMap.put("KW_IS", "IS"); + xlateMap.put("KW_NULL", "NULL"); + xlateMap.put("KW_CREATE", "CREATE"); + xlateMap.put("KW_EXTERNAL", "EXTERNAL"); + xlateMap.put("KW_ALTER", "ALTER"); + xlateMap.put("KW_DESCRIBE", "DESCRIBE"); + xlateMap.put("KW_DROP", "DROP"); + xlateMap.put("KW_RENAME", "RENAME"); + xlateMap.put("KW_TO", "TO"); + xlateMap.put("KW_COMMENT", "COMMENT"); + xlateMap.put("KW_BOOLEAN", "BOOLEAN"); + xlateMap.put("KW_TINYINT", "TINYINT"); + xlateMap.put("KW_SMALLINT", "SMALLINT"); + xlateMap.put("KW_INT", "INT"); + xlateMap.put("KW_BIGINT", "BIGINT"); + xlateMap.put("KW_FLOAT", "FLOAT"); + xlateMap.put("KW_DOUBLE", "DOUBLE"); + xlateMap.put("KW_DATE", "DATE"); + xlateMap.put("KW_DATETIME", "DATETIME"); + xlateMap.put("KW_TIMESTAMP", "TIMESTAMP"); + xlateMap.put("KW_STRING", "STRING"); + xlateMap.put("KW_BINARY", "BINARY"); + xlateMap.put("KW_ARRAY", "ARRAY"); + xlateMap.put("KW_MAP", "MAP"); + xlateMap.put("KW_REDUCE", "REDUCE"); + xlateMap.put("KW_PARTITIONED", "PARTITIONED"); + xlateMap.put("KW_CLUSTERED", "CLUSTERED"); + xlateMap.put("KW_SORTED", "SORTED"); + xlateMap.put("KW_INTO", "INTO"); + xlateMap.put("KW_BUCKETS", "BUCKETS"); + xlateMap.put("KW_ROW", "ROW"); + xlateMap.put("KW_FORMAT", "FORMAT"); + xlateMap.put("KW_DELIMITED", "DELIMITED"); + xlateMap.put("KW_FIELDS", "FIELDS"); + xlateMap.put("KW_TERMINATED", "TERMINATED"); + xlateMap.put("KW_COLLECTION", "COLLECTION"); + xlateMap.put("KW_ITEMS", "ITEMS"); + xlateMap.put("KW_KEYS", "KEYS"); + xlateMap.put("KW_KEY_TYPE", "\$KEY\$"); + xlateMap.put("KW_LINES", "LINES"); + xlateMap.put("KW_STORED", "STORED"); + xlateMap.put("KW_SEQUENCEFILE", "SEQUENCEFILE"); + xlateMap.put("KW_TEXTFILE", "TEXTFILE"); + xlateMap.put("KW_INPUTFORMAT", "INPUTFORMAT"); + xlateMap.put("KW_OUTPUTFORMAT", "OUTPUTFORMAT"); + xlateMap.put("KW_LOCATION", "LOCATION"); + xlateMap.put("KW_TABLESAMPLE", "TABLESAMPLE"); + xlateMap.put("KW_BUCKET", "BUCKET"); + xlateMap.put("KW_OUT", "OUT"); + xlateMap.put("KW_OF", "OF"); + xlateMap.put("KW_CAST", "CAST"); + xlateMap.put("KW_ADD", "ADD"); + xlateMap.put("KW_REPLACE", "REPLACE"); + xlateMap.put("KW_COLUMNS", "COLUMNS"); + xlateMap.put("KW_RLIKE", "RLIKE"); + xlateMap.put("KW_REGEXP", "REGEXP"); + xlateMap.put("KW_TEMPORARY", "TEMPORARY"); + xlateMap.put("KW_FUNCTION", "FUNCTION"); + xlateMap.put("KW_EXPLAIN", "EXPLAIN"); + xlateMap.put("KW_EXTENDED", "EXTENDED"); + xlateMap.put("KW_SERDE", "SERDE"); + xlateMap.put("KW_WITH", "WITH"); + xlateMap.put("KW_SERDEPROPERTIES", "SERDEPROPERTIES"); + xlateMap.put("KW_LIMIT", "LIMIT"); + xlateMap.put("KW_SET", "SET"); + xlateMap.put("KW_PROPERTIES", "TBLPROPERTIES"); + xlateMap.put("KW_VALUE_TYPE", "\$VALUE\$"); + xlateMap.put("KW_ELEM_TYPE", "\$ELEM\$"); + xlateMap.put("KW_DEFINED", "DEFINED"); + xlateMap.put("KW_SUBQUERY", "SUBQUERY"); + xlateMap.put("KW_REWRITE", "REWRITE"); + xlateMap.put("KW_UPDATE", "UPDATE"); + xlateMap.put("KW_VALUES", "VALUES"); + xlateMap.put("KW_PURGE", "PURGE"); + + + // Operators + xlateMap.put("DOT", "."); + xlateMap.put("COLON", ":"); + xlateMap.put("COMMA", ","); + xlateMap.put("SEMICOLON", ");"); + + xlateMap.put("LPAREN", "("); + xlateMap.put("RPAREN", ")"); + xlateMap.put("LSQUARE", "["); + xlateMap.put("RSQUARE", "]"); + + xlateMap.put("EQUAL", "="); + xlateMap.put("NOTEQUAL", "<>"); + xlateMap.put("EQUAL_NS", "<=>"); + xlateMap.put("LESSTHANOREQUALTO", "<="); + xlateMap.put("LESSTHAN", "<"); + xlateMap.put("GREATERTHANOREQUALTO", ">="); + xlateMap.put("GREATERTHAN", ">"); + + xlateMap.put("DIVIDE", "/"); + xlateMap.put("PLUS", "+"); + xlateMap.put("MINUS", "-"); + xlateMap.put("STAR", "*"); + xlateMap.put("MOD", "\%"); + + xlateMap.put("AMPERSAND", "&"); + xlateMap.put("TILDE", "~"); + xlateMap.put("BITWISEOR", "|"); + xlateMap.put("BITWISEXOR", "^"); + xlateMap.put("CharSetLiteral", "\\'"); + } + + public static Collection getKeywords() { + return xlateMap.values(); + } + + private static String xlate(String name) { + + String ret = xlateMap.get(name); + if (ret == null) { + ret = name; + } + + return ret; + } + + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + errors.add(new ParseError(this, e, tokenNames)); + } + + @Override + public String getErrorHeader(RecognitionException e) { + String header = null; + if (e.charPositionInLine < 0 && input.LT(-1) != null) { + Token t = input.LT(-1); + header = "line " + t.getLine() + ":" + t.getCharPositionInLine(); + } else { + header = super.getErrorHeader(e); + } + + return header; + } + + @Override + public String getErrorMessage(RecognitionException e, String[] tokenNames) { + String msg = null; + + // Translate the token names to something that the user can understand + String[] xlateNames = new String[tokenNames.length]; + for (int i = 0; i < tokenNames.length; ++i) { + xlateNames[i] = SparkSqlParser.xlate(tokenNames[i]); + } + + if (e instanceof NoViableAltException) { + @SuppressWarnings("unused") + NoViableAltException nvae = (NoViableAltException) e; + // for development, can add + // "decision=<<"+nvae.grammarDecisionDescription+">>" + // and "(decision="+nvae.decisionNumber+") and + // "state "+nvae.stateNumber + msg = "cannot recognize input near" + + (input.LT(1) != null ? " " + getTokenErrorDisplay(input.LT(1)) : "") + + (input.LT(2) != null ? " " + getTokenErrorDisplay(input.LT(2)) : "") + + (input.LT(3) != null ? " " + getTokenErrorDisplay(input.LT(3)) : ""); + } else if (e instanceof MismatchedTokenException) { + MismatchedTokenException mte = (MismatchedTokenException) e; + msg = super.getErrorMessage(e, xlateNames) + (input.LT(-1) == null ? "":" near '" + input.LT(-1).getText()) + "'"; + } else if (e instanceof FailedPredicateException) { + FailedPredicateException fpe = (FailedPredicateException) e; + msg = "Failed to recognize predicate '" + fpe.token.getText() + "'. Failed rule: '" + fpe.ruleName + "'"; + } else { + msg = super.getErrorMessage(e, xlateNames); + } + + if (msgs.size() > 0) { + msg = msg + " in " + msgs.peek(); + } + return msg; + } + + public void pushMsg(String msg, RecognizerSharedState state) { + // ANTLR generated code does not wrap the @init code wit this backtracking check, + // even if the matching @after has it. If we have parser rules with that are doing + // some lookahead with syntactic predicates this can cause the push() and pop() calls + // to become unbalanced, so make sure both push/pop check the backtracking state. + if (state.backtracking == 0) { + msgs.push(msg); + } + } + + public void popMsg(RecognizerSharedState state) { + if (state.backtracking == 0) { + Object o = msgs.pop(); + } + } + + // counter to generate unique union aliases + private int aliasCounter; + private String generateUnionAlias() { + return "_u" + (++aliasCounter); + } + private char [] excludedCharForColumnName = {'.', ':'}; + private boolean containExcludedCharForCreateTableColumnName(String input) { + for(char c : excludedCharForColumnName) { + if(input.indexOf(c)>-1) { + return true; + } + } + return false; + } + private CommonTree throwSetOpException() throws RecognitionException { + throw new FailedPredicateException(input, "orderByClause clusterByClause distributeByClause sortByClause limitClause can only be applied to the whole union.", ""); + } + private CommonTree throwColumnNameException() throws RecognitionException { + throw new FailedPredicateException(input, Arrays.toString(excludedCharForColumnName) + " can not be used in column name in create table statement.", ""); + } + private Configuration hiveConf; + public void setHiveConf(Configuration hiveConf) { + this.hiveConf = hiveConf; + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + if(hiveConf==null){ + return false; + } + return !HiveConf.getBoolVar(hiveConf, HiveConf.ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS); + } +} + +@rulecatch { +catch (RecognitionException e) { + reportError(e); + throw e; +} +} + +// starting rule +statement + : explainStatement EOF + | execStatement EOF + ; + +explainStatement +@init { pushMsg("explain statement", state); } +@after { popMsg(state); } + : KW_EXPLAIN ( + explainOption* execStatement -> ^(TOK_EXPLAIN execStatement explainOption*) + | + KW_REWRITE queryStatementExpression[true] -> ^(TOK_EXPLAIN_SQ_REWRITE queryStatementExpression)) + ; + +explainOption +@init { msgs.push("explain option"); } +@after { msgs.pop(); } + : KW_EXTENDED|KW_FORMATTED|KW_DEPENDENCY|KW_LOGICAL|KW_AUTHORIZATION + ; + +execStatement +@init { pushMsg("statement", state); } +@after { popMsg(state); } + : queryStatementExpression[true] + | loadStatement + | exportStatement + | importStatement + | ddlStatement + | deleteStatement + | updateStatement + | sqlTransactionStatement + ; + +loadStatement +@init { pushMsg("load statement", state); } +@after { popMsg(state); } + : KW_LOAD KW_DATA (islocal=KW_LOCAL)? KW_INPATH (path=StringLiteral) (isoverwrite=KW_OVERWRITE)? KW_INTO KW_TABLE (tab=tableOrPartition) + -> ^(TOK_LOAD $path $tab $islocal? $isoverwrite?) + ; + +replicationClause +@init { pushMsg("replication clause", state); } +@after { popMsg(state); } + : KW_FOR (isMetadataOnly=KW_METADATA)? KW_REPLICATION LPAREN (replId=StringLiteral) RPAREN + -> ^(TOK_REPLICATION $replId $isMetadataOnly?) + ; + +exportStatement +@init { pushMsg("export statement", state); } +@after { popMsg(state); } + : KW_EXPORT + KW_TABLE (tab=tableOrPartition) + KW_TO (path=StringLiteral) + replicationClause? + -> ^(TOK_EXPORT $tab $path replicationClause?) + ; + +importStatement +@init { pushMsg("import statement", state); } +@after { popMsg(state); } + : KW_IMPORT + ((ext=KW_EXTERNAL)? KW_TABLE (tab=tableOrPartition))? + KW_FROM (path=StringLiteral) + tableLocation? + -> ^(TOK_IMPORT $path $tab? $ext? tableLocation?) + ; + +ddlStatement +@init { pushMsg("ddl statement", state); } +@after { popMsg(state); } + : createDatabaseStatement + | switchDatabaseStatement + | dropDatabaseStatement + | createTableStatement + | dropTableStatement + | truncateTableStatement + | alterStatement + | descStatement + | showStatement + | metastoreCheck + | createViewStatement + | dropViewStatement + | createFunctionStatement + | createMacroStatement + | createIndexStatement + | dropIndexStatement + | dropFunctionStatement + | reloadFunctionStatement + | dropMacroStatement + | analyzeStatement + | lockStatement + | unlockStatement + | lockDatabase + | unlockDatabase + | createRoleStatement + | dropRoleStatement + | (grantPrivileges) => grantPrivileges + | (revokePrivileges) => revokePrivileges + | showGrants + | showRoleGrants + | showRolePrincipals + | showRoles + | grantRole + | revokeRole + | setRole + | showCurrentRole + ; + +ifExists +@init { pushMsg("if exists clause", state); } +@after { popMsg(state); } + : KW_IF KW_EXISTS + -> ^(TOK_IFEXISTS) + ; + +restrictOrCascade +@init { pushMsg("restrict or cascade clause", state); } +@after { popMsg(state); } + : KW_RESTRICT + -> ^(TOK_RESTRICT) + | KW_CASCADE + -> ^(TOK_CASCADE) + ; + +ifNotExists +@init { pushMsg("if not exists clause", state); } +@after { popMsg(state); } + : KW_IF KW_NOT KW_EXISTS + -> ^(TOK_IFNOTEXISTS) + ; + +storedAsDirs +@init { pushMsg("stored as directories", state); } +@after { popMsg(state); } + : KW_STORED KW_AS KW_DIRECTORIES + -> ^(TOK_STOREDASDIRS) + ; + +orReplace +@init { pushMsg("or replace clause", state); } +@after { popMsg(state); } + : KW_OR KW_REPLACE + -> ^(TOK_ORREPLACE) + ; + +createDatabaseStatement +@init { pushMsg("create database statement", state); } +@after { popMsg(state); } + : KW_CREATE (KW_DATABASE|KW_SCHEMA) + ifNotExists? + name=identifier + databaseComment? + dbLocation? + (KW_WITH KW_DBPROPERTIES dbprops=dbProperties)? + -> ^(TOK_CREATEDATABASE $name ifNotExists? dbLocation? databaseComment? $dbprops?) + ; + +dbLocation +@init { pushMsg("database location specification", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_DATABASELOCATION $locn) + ; + +dbProperties +@init { pushMsg("dbproperties", state); } +@after { popMsg(state); } + : + LPAREN dbPropertiesList RPAREN -> ^(TOK_DATABASEPROPERTIES dbPropertiesList) + ; + +dbPropertiesList +@init { pushMsg("database properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_DBPROPLIST keyValueProperty+) + ; + + +switchDatabaseStatement +@init { pushMsg("switch database statement", state); } +@after { popMsg(state); } + : KW_USE identifier + -> ^(TOK_SWITCHDATABASE identifier) + ; + +dropDatabaseStatement +@init { pushMsg("drop database statement", state); } +@after { popMsg(state); } + : KW_DROP (KW_DATABASE|KW_SCHEMA) ifExists? identifier restrictOrCascade? + -> ^(TOK_DROPDATABASE identifier ifExists? restrictOrCascade?) + ; + +databaseComment +@init { pushMsg("database's comment", state); } +@after { popMsg(state); } + : KW_COMMENT comment=StringLiteral + -> ^(TOK_DATABASECOMMENT $comment) + ; + +createTableStatement +@init { pushMsg("create table statement", state); } +@after { popMsg(state); } + : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName + ( like=KW_LIKE likeName=tableName + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + | (LPAREN columnNameTypeList RPAREN)? + tableComment? + tablePartition? + tableBuckets? + tableSkewed? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + (KW_AS selectStatementWithCTE)? + ) + -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? + ^(TOK_LIKETABLE $likeName?) + columnNameTypeList? + tableComment? + tablePartition? + tableBuckets? + tableSkewed? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + selectStatementWithCTE? + ) + ; + +truncateTableStatement +@init { pushMsg("truncate table statement", state); } +@after { popMsg(state); } + : KW_TRUNCATE KW_TABLE tablePartitionPrefix (KW_COLUMNS LPAREN columnNameList RPAREN)? -> ^(TOK_TRUNCATETABLE tablePartitionPrefix columnNameList?); + +createIndexStatement +@init { pushMsg("create index statement", state);} +@after {popMsg(state);} + : KW_CREATE KW_INDEX indexName=identifier + KW_ON KW_TABLE tab=tableName LPAREN indexedCols=columnNameList RPAREN + KW_AS typeName=StringLiteral + autoRebuild? + indexPropertiesPrefixed? + indexTblName? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + indexComment? + ->^(TOK_CREATEINDEX $indexName $typeName $tab $indexedCols + autoRebuild? + indexPropertiesPrefixed? + indexTblName? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + indexComment?) + ; + +indexComment +@init { pushMsg("comment on an index", state);} +@after {popMsg(state);} + : + KW_COMMENT comment=StringLiteral -> ^(TOK_INDEXCOMMENT $comment) + ; + +autoRebuild +@init { pushMsg("auto rebuild index", state);} +@after {popMsg(state);} + : KW_WITH KW_DEFERRED KW_REBUILD + ->^(TOK_DEFERRED_REBUILDINDEX) + ; + +indexTblName +@init { pushMsg("index table name", state);} +@after {popMsg(state);} + : KW_IN KW_TABLE indexTbl=tableName + ->^(TOK_CREATEINDEX_INDEXTBLNAME $indexTbl) + ; + +indexPropertiesPrefixed +@init { pushMsg("table properties with prefix", state); } +@after { popMsg(state); } + : + KW_IDXPROPERTIES! indexProperties + ; + +indexProperties +@init { pushMsg("index properties", state); } +@after { popMsg(state); } + : + LPAREN indexPropertiesList RPAREN -> ^(TOK_INDEXPROPERTIES indexPropertiesList) + ; + +indexPropertiesList +@init { pushMsg("index properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_INDEXPROPLIST keyValueProperty+) + ; + +dropIndexStatement +@init { pushMsg("drop index statement", state);} +@after {popMsg(state);} + : KW_DROP KW_INDEX ifExists? indexName=identifier KW_ON tab=tableName + ->^(TOK_DROPINDEX $indexName $tab ifExists?) + ; + +dropTableStatement +@init { pushMsg("drop statement", state); } +@after { popMsg(state); } + : KW_DROP KW_TABLE ifExists? tableName KW_PURGE? replicationClause? + -> ^(TOK_DROPTABLE tableName ifExists? KW_PURGE? replicationClause?) + ; + +alterStatement +@init { pushMsg("alter statement", state); } +@after { popMsg(state); } + : KW_ALTER KW_TABLE tableName alterTableStatementSuffix -> ^(TOK_ALTERTABLE tableName alterTableStatementSuffix) + | KW_ALTER KW_VIEW tableName KW_AS? alterViewStatementSuffix -> ^(TOK_ALTERVIEW tableName alterViewStatementSuffix) + | KW_ALTER KW_INDEX alterIndexStatementSuffix -> alterIndexStatementSuffix + | KW_ALTER (KW_DATABASE|KW_SCHEMA) alterDatabaseStatementSuffix -> alterDatabaseStatementSuffix + ; + +alterTableStatementSuffix +@init { pushMsg("alter table statement", state); } +@after { popMsg(state); } + : (alterStatementSuffixRename[true]) => alterStatementSuffixRename[true] + | alterStatementSuffixDropPartitions[true] + | alterStatementSuffixAddPartitions[true] + | alterStatementSuffixTouch + | alterStatementSuffixArchive + | alterStatementSuffixUnArchive + | alterStatementSuffixProperties + | alterStatementSuffixSkewedby + | alterStatementSuffixExchangePartition + | alterStatementPartitionKeyType + | partitionSpec? alterTblPartitionStatementSuffix -> alterTblPartitionStatementSuffix partitionSpec? + ; + +alterTblPartitionStatementSuffix +@init {pushMsg("alter table partition statement suffix", state);} +@after {popMsg(state);} + : alterStatementSuffixFileFormat + | alterStatementSuffixLocation + | alterStatementSuffixMergeFiles + | alterStatementSuffixSerdeProperties + | alterStatementSuffixRenamePart + | alterStatementSuffixBucketNum + | alterTblPartitionStatementSuffixSkewedLocation + | alterStatementSuffixClusterbySortby + | alterStatementSuffixCompact + | alterStatementSuffixUpdateStatsCol + | alterStatementSuffixRenameCol + | alterStatementSuffixAddCol + ; + +alterStatementPartitionKeyType +@init {msgs.push("alter partition key type"); } +@after {msgs.pop();} + : KW_PARTITION KW_COLUMN LPAREN columnNameType RPAREN + -> ^(TOK_ALTERTABLE_PARTCOLTYPE columnNameType) + ; + +alterViewStatementSuffix +@init { pushMsg("alter view statement", state); } +@after { popMsg(state); } + : alterViewSuffixProperties + | alterStatementSuffixRename[false] + | alterStatementSuffixAddPartitions[false] + | alterStatementSuffixDropPartitions[false] + | selectStatementWithCTE + ; + +alterIndexStatementSuffix +@init { pushMsg("alter index statement", state); } +@after { popMsg(state); } + : indexName=identifier KW_ON tableName partitionSpec? + ( + KW_REBUILD + ->^(TOK_ALTERINDEX_REBUILD tableName $indexName partitionSpec?) + | + KW_SET KW_IDXPROPERTIES + indexProperties + ->^(TOK_ALTERINDEX_PROPERTIES tableName $indexName indexProperties) + ) + ; + +alterDatabaseStatementSuffix +@init { pushMsg("alter database statement", state); } +@after { popMsg(state); } + : alterDatabaseSuffixProperties + | alterDatabaseSuffixSetOwner + ; + +alterDatabaseSuffixProperties +@init { pushMsg("alter database properties statement", state); } +@after { popMsg(state); } + : name=identifier KW_SET KW_DBPROPERTIES dbProperties + -> ^(TOK_ALTERDATABASE_PROPERTIES $name dbProperties) + ; + +alterDatabaseSuffixSetOwner +@init { pushMsg("alter database set owner", state); } +@after { popMsg(state); } + : dbName=identifier KW_SET KW_OWNER principalName + -> ^(TOK_ALTERDATABASE_OWNER $dbName principalName) + ; + +alterStatementSuffixRename[boolean table] +@init { pushMsg("rename statement", state); } +@after { popMsg(state); } + : KW_RENAME KW_TO tableName + -> { table }? ^(TOK_ALTERTABLE_RENAME tableName) + -> ^(TOK_ALTERVIEW_RENAME tableName) + ; + +alterStatementSuffixAddCol +@init { pushMsg("add column statement", state); } +@after { popMsg(state); } + : (add=KW_ADD | replace=KW_REPLACE) KW_COLUMNS LPAREN columnNameTypeList RPAREN restrictOrCascade? + -> {$add != null}? ^(TOK_ALTERTABLE_ADDCOLS columnNameTypeList restrictOrCascade?) + -> ^(TOK_ALTERTABLE_REPLACECOLS columnNameTypeList restrictOrCascade?) + ; + +alterStatementSuffixRenameCol +@init { pushMsg("rename column name", state); } +@after { popMsg(state); } + : KW_CHANGE KW_COLUMN? oldName=identifier newName=identifier colType (KW_COMMENT comment=StringLiteral)? alterStatementChangeColPosition? restrictOrCascade? + ->^(TOK_ALTERTABLE_RENAMECOL $oldName $newName colType $comment? alterStatementChangeColPosition? restrictOrCascade?) + ; + +alterStatementSuffixUpdateStatsCol +@init { pushMsg("update column statistics", state); } +@after { popMsg(state); } + : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? + ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) + ; + +alterStatementChangeColPosition + : first=KW_FIRST|KW_AFTER afterCol=identifier + ->{$first != null}? ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION ) + -> ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION $afterCol) + ; + +alterStatementSuffixAddPartitions[boolean table] +@init { pushMsg("add partition statement", state); } +@after { popMsg(state); } + : KW_ADD ifNotExists? alterStatementSuffixAddPartitionsElement+ + -> { table }? ^(TOK_ALTERTABLE_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) + -> ^(TOK_ALTERVIEW_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) + ; + +alterStatementSuffixAddPartitionsElement + : partitionSpec partitionLocation? + ; + +alterStatementSuffixTouch +@init { pushMsg("touch statement", state); } +@after { popMsg(state); } + : KW_TOUCH (partitionSpec)* + -> ^(TOK_ALTERTABLE_TOUCH (partitionSpec)*) + ; + +alterStatementSuffixArchive +@init { pushMsg("archive statement", state); } +@after { popMsg(state); } + : KW_ARCHIVE (partitionSpec)* + -> ^(TOK_ALTERTABLE_ARCHIVE (partitionSpec)*) + ; + +alterStatementSuffixUnArchive +@init { pushMsg("unarchive statement", state); } +@after { popMsg(state); } + : KW_UNARCHIVE (partitionSpec)* + -> ^(TOK_ALTERTABLE_UNARCHIVE (partitionSpec)*) + ; + +partitionLocation +@init { pushMsg("partition location", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_PARTITIONLOCATION $locn) + ; + +alterStatementSuffixDropPartitions[boolean table] +@init { pushMsg("drop partition statement", state); } +@after { popMsg(state); } + : KW_DROP ifExists? dropPartitionSpec (COMMA dropPartitionSpec)* KW_PURGE? replicationClause? + -> { table }? ^(TOK_ALTERTABLE_DROPPARTS dropPartitionSpec+ ifExists? KW_PURGE? replicationClause?) + -> ^(TOK_ALTERVIEW_DROPPARTS dropPartitionSpec+ ifExists? replicationClause?) + ; + +alterStatementSuffixProperties +@init { pushMsg("alter properties statement", state); } +@after { popMsg(state); } + : KW_SET KW_TBLPROPERTIES tableProperties + -> ^(TOK_ALTERTABLE_PROPERTIES tableProperties) + | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties + -> ^(TOK_ALTERTABLE_DROPPROPERTIES tableProperties ifExists?) + ; + +alterViewSuffixProperties +@init { pushMsg("alter view properties statement", state); } +@after { popMsg(state); } + : KW_SET KW_TBLPROPERTIES tableProperties + -> ^(TOK_ALTERVIEW_PROPERTIES tableProperties) + | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties + -> ^(TOK_ALTERVIEW_DROPPROPERTIES tableProperties ifExists?) + ; + +alterStatementSuffixSerdeProperties +@init { pushMsg("alter serdes statement", state); } +@after { popMsg(state); } + : KW_SET KW_SERDE serdeName=StringLiteral (KW_WITH KW_SERDEPROPERTIES tableProperties)? + -> ^(TOK_ALTERTABLE_SERIALIZER $serdeName tableProperties?) + | KW_SET KW_SERDEPROPERTIES tableProperties + -> ^(TOK_ALTERTABLE_SERDEPROPERTIES tableProperties) + ; + +tablePartitionPrefix +@init {pushMsg("table partition prefix", state);} +@after {popMsg(state);} + : tableName partitionSpec? + ->^(TOK_TABLE_PARTITION tableName partitionSpec?) + ; + +alterStatementSuffixFileFormat +@init {pushMsg("alter fileformat statement", state); } +@after {popMsg(state);} + : KW_SET KW_FILEFORMAT fileFormat + -> ^(TOK_ALTERTABLE_FILEFORMAT fileFormat) + ; + +alterStatementSuffixClusterbySortby +@init {pushMsg("alter partition cluster by sort by statement", state);} +@after {popMsg(state);} + : KW_NOT KW_CLUSTERED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_CLUSTERED) + | KW_NOT KW_SORTED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_SORTED) + | tableBuckets -> ^(TOK_ALTERTABLE_CLUSTER_SORT tableBuckets) + ; + +alterTblPartitionStatementSuffixSkewedLocation +@init {pushMsg("alter partition skewed location", state);} +@after {popMsg(state);} + : KW_SET KW_SKEWED KW_LOCATION skewedLocations + -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations) + ; + +skewedLocations +@init { pushMsg("skewed locations", state); } +@after { popMsg(state); } + : + LPAREN skewedLocationsList RPAREN -> ^(TOK_SKEWED_LOCATIONS skewedLocationsList) + ; + +skewedLocationsList +@init { pushMsg("skewed locations list", state); } +@after { popMsg(state); } + : + skewedLocationMap (COMMA skewedLocationMap)* -> ^(TOK_SKEWED_LOCATION_LIST skewedLocationMap+) + ; + +skewedLocationMap +@init { pushMsg("specifying skewed location map", state); } +@after { popMsg(state); } + : + key=skewedValueLocationElement EQUAL value=StringLiteral -> ^(TOK_SKEWED_LOCATION_MAP $key $value) + ; + +alterStatementSuffixLocation +@init {pushMsg("alter location", state);} +@after {popMsg(state);} + : KW_SET KW_LOCATION newLoc=StringLiteral + -> ^(TOK_ALTERTABLE_LOCATION $newLoc) + ; + + +alterStatementSuffixSkewedby +@init {pushMsg("alter skewed by statement", state);} +@after{popMsg(state);} + : tableSkewed + ->^(TOK_ALTERTABLE_SKEWED tableSkewed) + | + KW_NOT KW_SKEWED + ->^(TOK_ALTERTABLE_SKEWED) + | + KW_NOT storedAsDirs + ->^(TOK_ALTERTABLE_SKEWED storedAsDirs) + ; + +alterStatementSuffixExchangePartition +@init {pushMsg("alter exchange partition", state);} +@after{popMsg(state);} + : KW_EXCHANGE partitionSpec KW_WITH KW_TABLE exchangename=tableName + -> ^(TOK_ALTERTABLE_EXCHANGEPARTITION partitionSpec $exchangename) + ; + +alterStatementSuffixRenamePart +@init { pushMsg("alter table rename partition statement", state); } +@after { popMsg(state); } + : KW_RENAME KW_TO partitionSpec + ->^(TOK_ALTERTABLE_RENAMEPART partitionSpec) + ; + +alterStatementSuffixStatsPart +@init { pushMsg("alter table stats partition statement", state); } +@after { popMsg(state); } + : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? + ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) + ; + +alterStatementSuffixMergeFiles +@init { pushMsg("", state); } +@after { popMsg(state); } + : KW_CONCATENATE + -> ^(TOK_ALTERTABLE_MERGEFILES) + ; + +alterStatementSuffixBucketNum +@init { pushMsg("", state); } +@after { popMsg(state); } + : KW_INTO num=Number KW_BUCKETS + -> ^(TOK_ALTERTABLE_BUCKETS $num) + ; + +alterStatementSuffixCompact +@init { msgs.push("compaction request"); } +@after { msgs.pop(); } + : KW_COMPACT compactType=StringLiteral + -> ^(TOK_ALTERTABLE_COMPACT $compactType) + ; + + +fileFormat +@init { pushMsg("file format specification", state); } +@after { popMsg(state); } + : KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral KW_SERDE serdeCls=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? + -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $serdeCls $inDriver? $outDriver?) + | genericSpec=identifier -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) + ; + +tabTypeExpr +@init { pushMsg("specifying table types", state); } +@after { popMsg(state); } + : identifier (DOT^ identifier)? + (identifier (DOT^ + ( + (KW_ELEM_TYPE) => KW_ELEM_TYPE + | + (KW_KEY_TYPE) => KW_KEY_TYPE + | + (KW_VALUE_TYPE) => KW_VALUE_TYPE + | identifier + ))* + )? + ; + +partTypeExpr +@init { pushMsg("specifying table partitions", state); } +@after { popMsg(state); } + : tabTypeExpr partitionSpec? -> ^(TOK_TABTYPE tabTypeExpr partitionSpec?) + ; + +tabPartColTypeExpr +@init { pushMsg("specifying table partitions columnName", state); } +@after { popMsg(state); } + : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?) + ; + +descStatement +@init { pushMsg("describe statement", state); } +@after { popMsg(state); } + : + (KW_DESCRIBE|KW_DESC) + ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) KW_EXTENDED? (dbName=identifier) -> ^(TOK_DESCDATABASE $dbName KW_EXTENDED?) + | + (KW_FUNCTION) => KW_FUNCTION KW_EXTENDED? (name=descFuncNames) -> ^(TOK_DESCFUNCTION $name KW_EXTENDED?) + | + (KW_FORMATTED|KW_EXTENDED|KW_PRETTY) => ((descOptions=KW_FORMATTED|descOptions=KW_EXTENDED|descOptions=KW_PRETTY) parttype=tabPartColTypeExpr) -> ^(TOK_DESCTABLE $parttype $descOptions) + | + parttype=tabPartColTypeExpr -> ^(TOK_DESCTABLE $parttype) + ) + ; + +analyzeStatement +@init { pushMsg("analyze statement", state); } +@after { popMsg(state); } + : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) + | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))? + -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?) + ; + +showStatement +@init { pushMsg("show statement", state); } +@after { popMsg(state); } + : KW_SHOW (KW_DATABASES|KW_SCHEMAS) (KW_LIKE showStmtIdentifier)? -> ^(TOK_SHOWDATABASES showStmtIdentifier?) + | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES (TOK_FROM $db_name)? showStmtIdentifier?) + | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? + -> ^(TOK_SHOWCOLUMNS tableName $db_name?) + | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) + | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) + | KW_SHOW KW_CREATE ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name) + | + KW_TABLE tabName=tableName -> ^(TOK_SHOW_CREATETABLE $tabName) + ) + | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec? + -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?) + | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?) + | KW_SHOW KW_LOCKS + ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?) + | + (parttype=partTypeExpr)? (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWLOCKS $parttype? $isExtended?) + ) + | KW_SHOW (showOptions=KW_FORMATTED)? (KW_INDEX|KW_INDEXES) KW_ON showStmtIdentifier ((KW_FROM|KW_IN) db_name=identifier)? + -> ^(TOK_SHOWINDEXES showStmtIdentifier $showOptions? $db_name?) + | KW_SHOW KW_COMPACTIONS -> ^(TOK_SHOW_COMPACTIONS) + | KW_SHOW KW_TRANSACTIONS -> ^(TOK_SHOW_TRANSACTIONS) + | KW_SHOW KW_CONF StringLiteral -> ^(TOK_SHOWCONF StringLiteral) + ; + +lockStatement +@init { pushMsg("lock statement", state); } +@after { popMsg(state); } + : KW_LOCK KW_TABLE tableName partitionSpec? lockMode -> ^(TOK_LOCKTABLE tableName lockMode partitionSpec?) + ; + +lockDatabase +@init { pushMsg("lock database statement", state); } +@after { popMsg(state); } + : KW_LOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) lockMode -> ^(TOK_LOCKDB $dbName lockMode) + ; + +lockMode +@init { pushMsg("lock mode", state); } +@after { popMsg(state); } + : KW_SHARED | KW_EXCLUSIVE + ; + +unlockStatement +@init { pushMsg("unlock statement", state); } +@after { popMsg(state); } + : KW_UNLOCK KW_TABLE tableName partitionSpec? -> ^(TOK_UNLOCKTABLE tableName partitionSpec?) + ; + +unlockDatabase +@init { pushMsg("unlock database statement", state); } +@after { popMsg(state); } + : KW_UNLOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) -> ^(TOK_UNLOCKDB $dbName) + ; + +createRoleStatement +@init { pushMsg("create role", state); } +@after { popMsg(state); } + : KW_CREATE KW_ROLE roleName=identifier + -> ^(TOK_CREATEROLE $roleName) + ; + +dropRoleStatement +@init {pushMsg("drop role", state);} +@after {popMsg(state);} + : KW_DROP KW_ROLE roleName=identifier + -> ^(TOK_DROPROLE $roleName) + ; + +grantPrivileges +@init {pushMsg("grant privileges", state);} +@after {popMsg(state);} + : KW_GRANT privList=privilegeList + privilegeObject? + KW_TO principalSpecification + withGrantOption? + -> ^(TOK_GRANT $privList principalSpecification privilegeObject? withGrantOption?) + ; + +revokePrivileges +@init {pushMsg("revoke privileges", state);} +@afer {popMsg(state);} + : KW_REVOKE grantOptionFor? privilegeList privilegeObject? KW_FROM principalSpecification + -> ^(TOK_REVOKE privilegeList principalSpecification privilegeObject? grantOptionFor?) + ; + +grantRole +@init {pushMsg("grant role", state);} +@after {popMsg(state);} + : KW_GRANT KW_ROLE? identifier (COMMA identifier)* KW_TO principalSpecification withAdminOption? + -> ^(TOK_GRANT_ROLE principalSpecification withAdminOption? identifier+) + ; + +revokeRole +@init {pushMsg("revoke role", state);} +@after {popMsg(state);} + : KW_REVOKE adminOptionFor? KW_ROLE? identifier (COMMA identifier)* KW_FROM principalSpecification + -> ^(TOK_REVOKE_ROLE principalSpecification adminOptionFor? identifier+) + ; + +showRoleGrants +@init {pushMsg("show role grants", state);} +@after {popMsg(state);} + : KW_SHOW KW_ROLE KW_GRANT principalName + -> ^(TOK_SHOW_ROLE_GRANT principalName) + ; + + +showRoles +@init {pushMsg("show roles", state);} +@after {popMsg(state);} + : KW_SHOW KW_ROLES + -> ^(TOK_SHOW_ROLES) + ; + +showCurrentRole +@init {pushMsg("show current role", state);} +@after {popMsg(state);} + : KW_SHOW KW_CURRENT KW_ROLES + -> ^(TOK_SHOW_SET_ROLE) + ; + +setRole +@init {pushMsg("set role", state);} +@after {popMsg(state);} + : KW_SET KW_ROLE + ( + (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text]) + | + (KW_NONE) => (none=KW_NONE) -> ^(TOK_SHOW_SET_ROLE Identifier[$none.text]) + | + identifier -> ^(TOK_SHOW_SET_ROLE identifier) + ) + ; + +showGrants +@init {pushMsg("show grants", state);} +@after {popMsg(state);} + : KW_SHOW KW_GRANT principalName? (KW_ON privilegeIncludeColObject)? + -> ^(TOK_SHOW_GRANT principalName? privilegeIncludeColObject?) + ; + +showRolePrincipals +@init {pushMsg("show role principals", state);} +@after {popMsg(state);} + : KW_SHOW KW_PRINCIPALS roleName=identifier + -> ^(TOK_SHOW_ROLE_PRINCIPALS $roleName) + ; + + +privilegeIncludeColObject +@init {pushMsg("privilege object including columns", state);} +@after {popMsg(state);} + : (KW_ALL) => KW_ALL -> ^(TOK_RESOURCE_ALL) + | privObjectCols -> ^(TOK_PRIV_OBJECT_COL privObjectCols) + ; + +privilegeObject +@init {pushMsg("privilege object", state);} +@after {popMsg(state);} + : KW_ON privObject -> ^(TOK_PRIV_OBJECT privObject) + ; + +// database or table type. Type is optional, default type is table +privObject + : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) + | KW_TABLE? tableName partitionSpec? -> ^(TOK_TABLE_TYPE tableName partitionSpec?) + | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) + | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) + ; + +privObjectCols + : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) + | KW_TABLE? tableName (LPAREN cols=columnNameList RPAREN)? partitionSpec? -> ^(TOK_TABLE_TYPE tableName $cols? partitionSpec?) + | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) + | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) + ; + +privilegeList +@init {pushMsg("grant privilege list", state);} +@after {popMsg(state);} + : privlegeDef (COMMA privlegeDef)* + -> ^(TOK_PRIVILEGE_LIST privlegeDef+) + ; + +privlegeDef +@init {pushMsg("grant privilege", state);} +@after {popMsg(state);} + : privilegeType (LPAREN cols=columnNameList RPAREN)? + -> ^(TOK_PRIVILEGE privilegeType $cols?) + ; + +privilegeType +@init {pushMsg("privilege type", state);} +@after {popMsg(state);} + : KW_ALL -> ^(TOK_PRIV_ALL) + | KW_ALTER -> ^(TOK_PRIV_ALTER_METADATA) + | KW_UPDATE -> ^(TOK_PRIV_ALTER_DATA) + | KW_CREATE -> ^(TOK_PRIV_CREATE) + | KW_DROP -> ^(TOK_PRIV_DROP) + | KW_INDEX -> ^(TOK_PRIV_INDEX) + | KW_LOCK -> ^(TOK_PRIV_LOCK) + | KW_SELECT -> ^(TOK_PRIV_SELECT) + | KW_SHOW_DATABASE -> ^(TOK_PRIV_SHOW_DATABASE) + | KW_INSERT -> ^(TOK_PRIV_INSERT) + | KW_DELETE -> ^(TOK_PRIV_DELETE) + ; + +principalSpecification +@init { pushMsg("user/group/role name list", state); } +@after { popMsg(state); } + : principalName (COMMA principalName)* -> ^(TOK_PRINCIPAL_NAME principalName+) + ; + +principalName +@init {pushMsg("user|group|role name", state);} +@after {popMsg(state);} + : KW_USER principalIdentifier -> ^(TOK_USER principalIdentifier) + | KW_GROUP principalIdentifier -> ^(TOK_GROUP principalIdentifier) + | KW_ROLE identifier -> ^(TOK_ROLE identifier) + ; + +withGrantOption +@init {pushMsg("with grant option", state);} +@after {popMsg(state);} + : KW_WITH KW_GRANT KW_OPTION + -> ^(TOK_GRANT_WITH_OPTION) + ; + +grantOptionFor +@init {pushMsg("grant option for", state);} +@after {popMsg(state);} + : KW_GRANT KW_OPTION KW_FOR + -> ^(TOK_GRANT_OPTION_FOR) +; + +adminOptionFor +@init {pushMsg("admin option for", state);} +@after {popMsg(state);} + : KW_ADMIN KW_OPTION KW_FOR + -> ^(TOK_ADMIN_OPTION_FOR) +; + +withAdminOption +@init {pushMsg("with admin option", state);} +@after {popMsg(state);} + : KW_WITH KW_ADMIN KW_OPTION + -> ^(TOK_GRANT_WITH_ADMIN_OPTION) + ; + +metastoreCheck +@init { pushMsg("metastore check statement", state); } +@after { popMsg(state); } + : KW_MSCK (repair=KW_REPAIR)? (KW_TABLE tableName partitionSpec? (COMMA partitionSpec)*)? + -> ^(TOK_MSCK $repair? (tableName partitionSpec*)?) + ; + +resourceList +@init { pushMsg("resource list", state); } +@after { popMsg(state); } + : + resource (COMMA resource)* -> ^(TOK_RESOURCE_LIST resource+) + ; + +resource +@init { pushMsg("resource", state); } +@after { popMsg(state); } + : + resType=resourceType resPath=StringLiteral -> ^(TOK_RESOURCE_URI $resType $resPath) + ; + +resourceType +@init { pushMsg("resource type", state); } +@after { popMsg(state); } + : + KW_JAR -> ^(TOK_JAR) + | + KW_FILE -> ^(TOK_FILE) + | + KW_ARCHIVE -> ^(TOK_ARCHIVE) + ; + +createFunctionStatement +@init { pushMsg("create function statement", state); } +@after { popMsg(state); } + : KW_CREATE (temp=KW_TEMPORARY)? KW_FUNCTION functionIdentifier KW_AS StringLiteral + (KW_USING rList=resourceList)? + -> {$temp != null}? ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList? TOK_TEMPORARY) + -> ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList?) + ; + +dropFunctionStatement +@init { pushMsg("drop function statement", state); } +@after { popMsg(state); } + : KW_DROP (temp=KW_TEMPORARY)? KW_FUNCTION ifExists? functionIdentifier + -> {$temp != null}? ^(TOK_DROPFUNCTION functionIdentifier ifExists? TOK_TEMPORARY) + -> ^(TOK_DROPFUNCTION functionIdentifier ifExists?) + ; + +reloadFunctionStatement +@init { pushMsg("reload function statement", state); } +@after { popMsg(state); } + : KW_RELOAD KW_FUNCTION -> ^(TOK_RELOADFUNCTION); + +createMacroStatement +@init { pushMsg("create macro statement", state); } +@after { popMsg(state); } + : KW_CREATE KW_TEMPORARY KW_MACRO Identifier + LPAREN columnNameTypeList? RPAREN expression + -> ^(TOK_CREATEMACRO Identifier columnNameTypeList? expression) + ; + +dropMacroStatement +@init { pushMsg("drop macro statement", state); } +@after { popMsg(state); } + : KW_DROP KW_TEMPORARY KW_MACRO ifExists? Identifier + -> ^(TOK_DROPMACRO Identifier ifExists?) + ; + +createViewStatement +@init { + pushMsg("create view statement", state); +} +@after { popMsg(state); } + : KW_CREATE (orReplace)? KW_VIEW (ifNotExists)? name=tableName + (LPAREN columnNameCommentList RPAREN)? tableComment? viewPartition? + tablePropertiesPrefixed? + KW_AS + selectStatementWithCTE + -> ^(TOK_CREATEVIEW $name orReplace? + ifNotExists? + columnNameCommentList? + tableComment? + viewPartition? + tablePropertiesPrefixed? + selectStatementWithCTE + ) + ; + +viewPartition +@init { pushMsg("view partition specification", state); } +@after { popMsg(state); } + : KW_PARTITIONED KW_ON LPAREN columnNameList RPAREN + -> ^(TOK_VIEWPARTCOLS columnNameList) + ; + +dropViewStatement +@init { pushMsg("drop view statement", state); } +@after { popMsg(state); } + : KW_DROP KW_VIEW ifExists? viewName -> ^(TOK_DROPVIEW viewName ifExists?) + ; + +showFunctionIdentifier +@init { pushMsg("identifier for show function statement", state); } +@after { popMsg(state); } + : functionIdentifier + | StringLiteral + ; + +showStmtIdentifier +@init { pushMsg("identifier for show statement", state); } +@after { popMsg(state); } + : identifier + | StringLiteral + ; + +tableComment +@init { pushMsg("table's comment", state); } +@after { popMsg(state); } + : + KW_COMMENT comment=StringLiteral -> ^(TOK_TABLECOMMENT $comment) + ; + +tablePartition +@init { pushMsg("table partition specification", state); } +@after { popMsg(state); } + : KW_PARTITIONED KW_BY LPAREN columnNameTypeList RPAREN + -> ^(TOK_TABLEPARTCOLS columnNameTypeList) + ; + +tableBuckets +@init { pushMsg("table buckets specification", state); } +@after { popMsg(state); } + : + KW_CLUSTERED KW_BY LPAREN bucketCols=columnNameList RPAREN (KW_SORTED KW_BY LPAREN sortCols=columnNameOrderList RPAREN)? KW_INTO num=Number KW_BUCKETS + -> ^(TOK_ALTERTABLE_BUCKETS $bucketCols $sortCols? $num) + ; + +tableSkewed +@init { pushMsg("table skewed specification", state); } +@after { popMsg(state); } + : + KW_SKEWED KW_BY LPAREN skewedCols=columnNameList RPAREN KW_ON LPAREN (skewedValues=skewedValueElement) RPAREN ((storedAsDirs) => storedAsDirs)? + -> ^(TOK_TABLESKEWED $skewedCols $skewedValues storedAsDirs?) + ; + +rowFormat +@init { pushMsg("serde specification", state); } +@after { popMsg(state); } + : rowFormatSerde -> ^(TOK_SERDE rowFormatSerde) + | rowFormatDelimited -> ^(TOK_SERDE rowFormatDelimited) + | -> ^(TOK_SERDE) + ; + +recordReader +@init { pushMsg("record reader specification", state); } +@after { popMsg(state); } + : KW_RECORDREADER StringLiteral -> ^(TOK_RECORDREADER StringLiteral) + | -> ^(TOK_RECORDREADER) + ; + +recordWriter +@init { pushMsg("record writer specification", state); } +@after { popMsg(state); } + : KW_RECORDWRITER StringLiteral -> ^(TOK_RECORDWRITER StringLiteral) + | -> ^(TOK_RECORDWRITER) + ; + +rowFormatSerde +@init { pushMsg("serde format specification", state); } +@after { popMsg(state); } + : KW_ROW KW_FORMAT KW_SERDE name=StringLiteral (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? + -> ^(TOK_SERDENAME $name $serdeprops?) + ; + +rowFormatDelimited +@init { pushMsg("serde properties specification", state); } +@after { popMsg(state); } + : + KW_ROW KW_FORMAT KW_DELIMITED tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat? + -> ^(TOK_SERDEPROPS tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat?) + ; + +tableRowFormat +@init { pushMsg("table row format specification", state); } +@after { popMsg(state); } + : + rowFormatDelimited + -> ^(TOK_TABLEROWFORMAT rowFormatDelimited) + | rowFormatSerde + -> ^(TOK_TABLESERIALIZER rowFormatSerde) + ; + +tablePropertiesPrefixed +@init { pushMsg("table properties with prefix", state); } +@after { popMsg(state); } + : + KW_TBLPROPERTIES! tableProperties + ; + +tableProperties +@init { pushMsg("table properties", state); } +@after { popMsg(state); } + : + LPAREN tablePropertiesList RPAREN -> ^(TOK_TABLEPROPERTIES tablePropertiesList) + ; + +tablePropertiesList +@init { pushMsg("table properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_TABLEPROPLIST keyValueProperty+) + | + keyProperty (COMMA keyProperty)* -> ^(TOK_TABLEPROPLIST keyProperty+) + ; + +keyValueProperty +@init { pushMsg("specifying key/value property", state); } +@after { popMsg(state); } + : + key=StringLiteral EQUAL value=StringLiteral -> ^(TOK_TABLEPROPERTY $key $value) + ; + +keyProperty +@init { pushMsg("specifying key property", state); } +@after { popMsg(state); } + : + key=StringLiteral -> ^(TOK_TABLEPROPERTY $key TOK_NULL) + ; + +tableRowFormatFieldIdentifier +@init { pushMsg("table row format's field separator", state); } +@after { popMsg(state); } + : + KW_FIELDS KW_TERMINATED KW_BY fldIdnt=StringLiteral (KW_ESCAPED KW_BY fldEscape=StringLiteral)? + -> ^(TOK_TABLEROWFORMATFIELD $fldIdnt $fldEscape?) + ; + +tableRowFormatCollItemsIdentifier +@init { pushMsg("table row format's column separator", state); } +@after { popMsg(state); } + : + KW_COLLECTION KW_ITEMS KW_TERMINATED KW_BY collIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATCOLLITEMS $collIdnt) + ; + +tableRowFormatMapKeysIdentifier +@init { pushMsg("table row format's map key separator", state); } +@after { popMsg(state); } + : + KW_MAP KW_KEYS KW_TERMINATED KW_BY mapKeysIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATMAPKEYS $mapKeysIdnt) + ; + +tableRowFormatLinesIdentifier +@init { pushMsg("table row format's line separator", state); } +@after { popMsg(state); } + : + KW_LINES KW_TERMINATED KW_BY linesIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATLINES $linesIdnt) + ; + +tableRowNullFormat +@init { pushMsg("table row format's null specifier", state); } +@after { popMsg(state); } + : + KW_NULL KW_DEFINED KW_AS nullIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATNULL $nullIdnt) + ; +tableFileFormat +@init { pushMsg("table file format specification", state); } +@after { popMsg(state); } + : + (KW_STORED KW_AS KW_INPUTFORMAT) => KW_STORED KW_AS KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? + -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $inDriver? $outDriver?) + | KW_STORED KW_BY storageHandler=StringLiteral + (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? + -> ^(TOK_STORAGEHANDLER $storageHandler $serdeprops?) + | KW_STORED KW_AS genericSpec=identifier + -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) + ; + +tableLocation +@init { pushMsg("table location specification", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_TABLELOCATION $locn) + ; + +columnNameTypeList +@init { pushMsg("column name type list", state); } +@after { popMsg(state); } + : columnNameType (COMMA columnNameType)* -> ^(TOK_TABCOLLIST columnNameType+) + ; + +columnNameColonTypeList +@init { pushMsg("column name type list", state); } +@after { popMsg(state); } + : columnNameColonType (COMMA columnNameColonType)* -> ^(TOK_TABCOLLIST columnNameColonType+) + ; + +columnNameList +@init { pushMsg("column name list", state); } +@after { popMsg(state); } + : columnName (COMMA columnName)* -> ^(TOK_TABCOLNAME columnName+) + ; + +columnName +@init { pushMsg("column name", state); } +@after { popMsg(state); } + : + identifier + ; + +extColumnName +@init { pushMsg("column name for complex types", state); } +@after { popMsg(state); } + : + identifier (DOT^ ((KW_ELEM_TYPE) => KW_ELEM_TYPE | (KW_KEY_TYPE) => KW_KEY_TYPE | (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier))* + ; + +columnNameOrderList +@init { pushMsg("column name order list", state); } +@after { popMsg(state); } + : columnNameOrder (COMMA columnNameOrder)* -> ^(TOK_TABCOLNAME columnNameOrder+) + ; + +skewedValueElement +@init { pushMsg("skewed value element", state); } +@after { popMsg(state); } + : + skewedColumnValues + | skewedColumnValuePairList + ; + +skewedColumnValuePairList +@init { pushMsg("column value pair list", state); } +@after { popMsg(state); } + : skewedColumnValuePair (COMMA skewedColumnValuePair)* -> ^(TOK_TABCOLVALUE_PAIR skewedColumnValuePair+) + ; + +skewedColumnValuePair +@init { pushMsg("column value pair", state); } +@after { popMsg(state); } + : + LPAREN colValues=skewedColumnValues RPAREN + -> ^(TOK_TABCOLVALUES $colValues) + ; + +skewedColumnValues +@init { pushMsg("column values", state); } +@after { popMsg(state); } + : skewedColumnValue (COMMA skewedColumnValue)* -> ^(TOK_TABCOLVALUE skewedColumnValue+) + ; + +skewedColumnValue +@init { pushMsg("column value", state); } +@after { popMsg(state); } + : + constant + ; + +skewedValueLocationElement +@init { pushMsg("skewed value location element", state); } +@after { popMsg(state); } + : + skewedColumnValue + | skewedColumnValuePair + ; + +columnNameOrder +@init { pushMsg("column name order", state); } +@after { popMsg(state); } + : identifier (asc=KW_ASC | desc=KW_DESC)? + -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC identifier) + -> ^(TOK_TABSORTCOLNAMEDESC identifier) + ; + +columnNameCommentList +@init { pushMsg("column name comment list", state); } +@after { popMsg(state); } + : columnNameComment (COMMA columnNameComment)* -> ^(TOK_TABCOLNAME columnNameComment+) + ; + +columnNameComment +@init { pushMsg("column name comment", state); } +@after { popMsg(state); } + : colName=identifier (KW_COMMENT comment=StringLiteral)? + -> ^(TOK_TABCOL $colName TOK_NULL $comment?) + ; + +columnRefOrder +@init { pushMsg("column order", state); } +@after { popMsg(state); } + : expression (asc=KW_ASC | desc=KW_DESC)? + -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC expression) + -> ^(TOK_TABSORTCOLNAMEDESC expression) + ; + +columnNameType +@init { pushMsg("column specification", state); } +@after { popMsg(state); } + : colName=identifier colType (KW_COMMENT comment=StringLiteral)? + -> {containExcludedCharForCreateTableColumnName($colName.text)}? {throwColumnNameException()} + -> {$comment == null}? ^(TOK_TABCOL $colName colType) + -> ^(TOK_TABCOL $colName colType $comment) + ; + +columnNameColonType +@init { pushMsg("column specification", state); } +@after { popMsg(state); } + : colName=identifier COLON colType (KW_COMMENT comment=StringLiteral)? + -> {$comment == null}? ^(TOK_TABCOL $colName colType) + -> ^(TOK_TABCOL $colName colType $comment) + ; + +colType +@init { pushMsg("column type", state); } +@after { popMsg(state); } + : type + ; + +colTypeList +@init { pushMsg("column type list", state); } +@after { popMsg(state); } + : colType (COMMA colType)* -> ^(TOK_COLTYPELIST colType+) + ; + +type + : primitiveType + | listType + | structType + | mapType + | unionType; + +primitiveType +@init { pushMsg("primitive type specification", state); } +@after { popMsg(state); } + : KW_TINYINT -> TOK_TINYINT + | KW_SMALLINT -> TOK_SMALLINT + | KW_INT -> TOK_INT + | KW_BIGINT -> TOK_BIGINT + | KW_BOOLEAN -> TOK_BOOLEAN + | KW_FLOAT -> TOK_FLOAT + | KW_DOUBLE -> TOK_DOUBLE + | KW_DATE -> TOK_DATE + | KW_DATETIME -> TOK_DATETIME + | KW_TIMESTAMP -> TOK_TIMESTAMP + // Uncomment to allow intervals as table column types + //| KW_INTERVAL KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH + //| KW_INTERVAL KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME + | KW_STRING -> TOK_STRING + | KW_BINARY -> TOK_BINARY + | KW_DECIMAL (LPAREN prec=Number (COMMA scale=Number)? RPAREN)? -> ^(TOK_DECIMAL $prec? $scale?) + | KW_VARCHAR LPAREN length=Number RPAREN -> ^(TOK_VARCHAR $length) + | KW_CHAR LPAREN length=Number RPAREN -> ^(TOK_CHAR $length) + ; + +listType +@init { pushMsg("list type", state); } +@after { popMsg(state); } + : KW_ARRAY LESSTHAN type GREATERTHAN -> ^(TOK_LIST type) + ; + +structType +@init { pushMsg("struct type", state); } +@after { popMsg(state); } + : KW_STRUCT LESSTHAN columnNameColonTypeList GREATERTHAN -> ^(TOK_STRUCT columnNameColonTypeList) + ; + +mapType +@init { pushMsg("map type", state); } +@after { popMsg(state); } + : KW_MAP LESSTHAN left=primitiveType COMMA right=type GREATERTHAN + -> ^(TOK_MAP $left $right) + ; + +unionType +@init { pushMsg("uniontype type", state); } +@after { popMsg(state); } + : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList) + ; + +setOperator +@init { pushMsg("set operator", state); } +@after { popMsg(state); } + : KW_UNION KW_ALL -> ^(TOK_UNIONALL) + | KW_UNION KW_DISTINCT? -> ^(TOK_UNIONDISTINCT) + ; + +queryStatementExpression[boolean topLevel] + : + /* Would be nice to do this as a gated semantic perdicate + But the predicate gets pushed as a lookahead decision. + Calling rule doesnot know about topLevel + */ + (w=withClause {topLevel}?)? + queryStatementExpressionBody[topLevel] { + if ($w.tree != null) { + $queryStatementExpressionBody.tree.insertChild(0, $w.tree); + } + } + -> queryStatementExpressionBody + ; + +queryStatementExpressionBody[boolean topLevel] + : + fromStatement[topLevel] + | regularBody[topLevel] + ; + +withClause + : + KW_WITH cteStatement (COMMA cteStatement)* -> ^(TOK_CTE cteStatement+) +; + +cteStatement + : + identifier KW_AS LPAREN queryStatementExpression[false] RPAREN + -> ^(TOK_SUBQUERY queryStatementExpression identifier) +; + +fromStatement[boolean topLevel] +: (singleFromStatement -> singleFromStatement) + (u=setOperator r=singleFromStatement + -> ^($u {$fromStatement.tree} $r) + )* + -> {u != null && topLevel}? ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + {$fromStatement.tree} + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) + ) + ) + -> {$fromStatement.tree} + ; + + +singleFromStatement + : + fromClause + ( b+=body )+ -> ^(TOK_QUERY fromClause body+) + ; + +/* +The valuesClause rule below ensures that the parse tree for +"insert into table FOO values (1,2),(3,4)" looks the same as +"insert into table FOO select a,b from (values(1,2),(3,4)) as BAR(a,b)" which itself is made to look +very similar to the tree for "insert into table FOO select a,b from BAR". Since virtual table name +is implicit, it's represented as TOK_ANONYMOUS. +*/ +regularBody[boolean topLevel] + : + i=insertClause + ( + s=selectStatement[topLevel] + {$s.tree.getFirstChildWithType(TOK_INSERT).replaceChildren(0, 0, $i.tree);} -> {$s.tree} + | + valuesClause + -> ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_VIRTUAL_TABLE ^(TOK_VIRTUAL_TABREF ^(TOK_ANONYMOUS)) valuesClause) + ) + ^(TOK_INSERT {$i.tree} ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))) + ) + ) + | + selectStatement[topLevel] + ; + +selectStatement[boolean topLevel] + : + ( + s=selectClause + f=fromClause? + w=whereClause? + g=groupByClause? + h=havingClause? + o=orderByClause? + c=clusterByClause? + d=distributeByClause? + sort=sortByClause? + win=window_clause? + l=limitClause? + -> ^(TOK_QUERY $f? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + $s $w? $g? $h? $o? $c? + $d? $sort? $win? $l?)) + ) + (set=setOpSelectStatement[$selectStatement.tree, topLevel])? + -> {set == null}? + {$selectStatement.tree} + -> {o==null && c==null && d==null && sort==null && l==null}? + {$set.tree} + -> {throwSetOpException()} + ; + +setOpSelectStatement[CommonTree t, boolean topLevel] + : + (u=setOperator b=simpleSelectStatement + -> {$setOpSelectStatement.tree != null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? + ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + ^(TOK_UNIONALL {$setOpSelectStatement.tree} $b) + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF)) + ) + ) + -> {$setOpSelectStatement.tree != null && $u.tree.getType()!=SparkSqlParser.TOK_UNIONDISTINCT}? + ^(TOK_UNIONALL {$setOpSelectStatement.tree} $b) + -> {$setOpSelectStatement.tree == null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? + ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + ^(TOK_UNIONALL {$t} $b) + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF)) + ) + ) + -> ^(TOK_UNIONALL {$t} $b) + )+ + o=orderByClause? + c=clusterByClause? + d=distributeByClause? + sort=sortByClause? + win=window_clause? + l=limitClause? + -> {o==null && c==null && d==null && sort==null && win==null && l==null && !topLevel}? + {$setOpSelectStatement.tree} + -> ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + {$setOpSelectStatement.tree} + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) + $o? $c? $d? $sort? $win? $l? + ) + ) + ; + +simpleSelectStatement + : + selectClause + fromClause? + whereClause? + groupByClause? + havingClause? + ((window_clause) => window_clause)? + -> ^(TOK_QUERY fromClause? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + selectClause whereClause? groupByClause? havingClause? window_clause?)) + ; + +selectStatementWithCTE + : + (w=withClause)? + selectStatement[true] { + if ($w.tree != null) { + $selectStatement.tree.insertChild(0, $w.tree); + } + } + -> selectStatement + ; + +body + : + insertClause + selectClause + lateralView? + whereClause? + groupByClause? + havingClause? + orderByClause? + clusterByClause? + distributeByClause? + sortByClause? + window_clause? + limitClause? -> ^(TOK_INSERT insertClause + selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? + distributeByClause? sortByClause? window_clause? limitClause?) + | + selectClause + lateralView? + whereClause? + groupByClause? + havingClause? + orderByClause? + clusterByClause? + distributeByClause? + sortByClause? + window_clause? + limitClause? -> ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? + distributeByClause? sortByClause? window_clause? limitClause?) + ; + +insertClause +@init { pushMsg("insert clause", state); } +@after { popMsg(state); } + : + KW_INSERT KW_OVERWRITE destination ifNotExists? -> ^(TOK_DESTINATION destination ifNotExists?) + | KW_INSERT KW_INTO KW_TABLE? tableOrPartition (LPAREN targetCols=columnNameList RPAREN)? + -> ^(TOK_INSERT_INTO tableOrPartition $targetCols?) + ; + +destination +@init { pushMsg("destination specification", state); } +@after { popMsg(state); } + : + (local = KW_LOCAL)? KW_DIRECTORY StringLiteral tableRowFormat? tableFileFormat? + -> ^(TOK_DIR StringLiteral $local? tableRowFormat? tableFileFormat?) + | KW_TABLE tableOrPartition -> tableOrPartition + ; + +limitClause +@init { pushMsg("limit clause", state); } +@after { popMsg(state); } + : + KW_LIMIT num=Number -> ^(TOK_LIMIT $num) + ; + +//DELETE FROM WHERE ...; +deleteStatement +@init { pushMsg("delete statement", state); } +@after { popMsg(state); } + : + KW_DELETE KW_FROM tableName (whereClause)? -> ^(TOK_DELETE_FROM tableName whereClause?) + ; + +/*SET = (3 + col2)*/ +columnAssignmentClause + : + tableOrColumn EQUAL^ precedencePlusExpression + ; + +/*SET col1 = 5, col2 = (4 + col4), ...*/ +setColumnsClause + : + KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* ) + ; + +/* + UPDATE
    + SET col1 = val1, col2 = val2... WHERE ... +*/ +updateStatement +@init { pushMsg("update statement", state); } +@after { popMsg(state); } + : + KW_UPDATE tableName setColumnsClause whereClause? -> ^(TOK_UPDATE_TABLE tableName setColumnsClause whereClause?) + ; + +/* +BEGIN user defined transaction boundaries; follows SQL 2003 standard exactly except for addition of +"setAutoCommitStatement" which is not in the standard doc but is supported by most SQL engines. +*/ +sqlTransactionStatement +@init { pushMsg("transaction statement", state); } +@after { popMsg(state); } + : + startTransactionStatement + | commitStatement + | rollbackStatement + | setAutoCommitStatement + ; + +startTransactionStatement + : + KW_START KW_TRANSACTION ( transactionMode ( COMMA transactionMode )* )? -> ^(TOK_START_TRANSACTION transactionMode*) + ; + +transactionMode + : + isolationLevel + | transactionAccessMode -> ^(TOK_TXN_ACCESS_MODE transactionAccessMode) + ; + +transactionAccessMode + : + KW_READ KW_ONLY -> TOK_TXN_READ_ONLY + | KW_READ KW_WRITE -> TOK_TXN_READ_WRITE + ; + +isolationLevel + : + KW_ISOLATION KW_LEVEL levelOfIsolation -> ^(TOK_ISOLATION_LEVEL levelOfIsolation) + ; + +/*READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE may be supported later*/ +levelOfIsolation + : + KW_SNAPSHOT -> TOK_ISOLATION_SNAPSHOT + ; + +commitStatement + : + KW_COMMIT ( KW_WORK )? -> TOK_COMMIT + ; + +rollbackStatement + : + KW_ROLLBACK ( KW_WORK )? -> TOK_ROLLBACK + ; +setAutoCommitStatement + : + KW_SET KW_AUTOCOMMIT booleanValueTok -> ^(TOK_SET_AUTOCOMMIT booleanValueTok) + ; +/* +END user defined transaction boundaries +*/ diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java new file mode 100644 index 000000000000..35ecdc5ad10a --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import org.antlr.runtime.RecognitionException; +import org.antlr.runtime.Token; +import org.antlr.runtime.TokenStream; +import org.antlr.runtime.tree.CommonErrorNode; + +public class ASTErrorNode extends ASTNode { + + /** + * + */ + private static final long serialVersionUID = 1L; + CommonErrorNode delegate; + + public ASTErrorNode(TokenStream input, Token start, Token stop, + RecognitionException e){ + delegate = new CommonErrorNode(input,start,stop,e); + } + + @Override + public boolean isNil() { return delegate.isNil(); } + + @Override + public int getType() { return delegate.getType(); } + + @Override + public String getText() { return delegate.getText(); } + @Override + public String toString() { return delegate.toString(); } +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java new file mode 100644 index 000000000000..33d9322b628e --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java @@ -0,0 +1,245 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.antlr.runtime.Token; +import org.antlr.runtime.tree.CommonTree; +import org.antlr.runtime.tree.Tree; +import org.apache.hadoop.hive.ql.lib.Node; + +public class ASTNode extends CommonTree implements Node, Serializable { + private static final long serialVersionUID = 1L; + private transient StringBuffer astStr; + private transient int startIndx = -1; + private transient int endIndx = -1; + private transient ASTNode rootNode; + private transient boolean isValidASTStr; + + public ASTNode() { + } + + /** + * Constructor. + * + * @param t + * Token for the CommonTree Node + */ + public ASTNode(Token t) { + super(t); + } + + public ASTNode(ASTNode node) { + super(node); + } + + @Override + public Tree dupNode() { + return new ASTNode(this); + } + + /* + * (non-Javadoc) + * + * @see org.apache.hadoop.hive.ql.lib.Node#getChildren() + */ + @Override + public ArrayList getChildren() { + if (super.getChildCount() == 0) { + return null; + } + + ArrayList ret_vec = new ArrayList(); + for (int i = 0; i < super.getChildCount(); ++i) { + ret_vec.add((Node) super.getChild(i)); + } + + return ret_vec; + } + + /* + * (non-Javadoc) + * + * @see org.apache.hadoop.hive.ql.lib.Node#getName() + */ + @Override + public String getName() { + return (Integer.valueOf(super.getToken().getType())).toString(); + } + + public String dump() { + StringBuilder sb = new StringBuilder("\n"); + dump(sb, ""); + return sb.toString(); + } + + private StringBuilder dump(StringBuilder sb, String ws) { + sb.append(ws); + sb.append(toString()); + sb.append("\n"); + + ArrayList children = getChildren(); + if (children != null) { + for (Node node : getChildren()) { + if (node instanceof ASTNode) { + ((ASTNode) node).dump(sb, ws + " "); + } else { + sb.append(ws); + sb.append(" NON-ASTNODE!!"); + sb.append("\n"); + } + } + } + return sb; + } + + private ASTNode getRootNodeWithValidASTStr(boolean useMemoizedRoot) { + if (useMemoizedRoot && rootNode != null && rootNode.parent == null && + rootNode.hasValidMemoizedString()) { + return rootNode; + } + ASTNode retNode = this; + while (retNode.parent != null) { + retNode = (ASTNode) retNode.parent; + } + rootNode=retNode; + if (!rootNode.isValidASTStr) { + rootNode.astStr = new StringBuffer(); + rootNode.toStringTree(rootNode); + rootNode.isValidASTStr = true; + } + return retNode; + } + + private boolean hasValidMemoizedString() { + return isValidASTStr && astStr != null; + } + + private void resetRootInformation() { + // Reset the previously stored rootNode string + if (rootNode != null) { + rootNode.astStr = null; + rootNode.isValidASTStr = false; + } + } + + private int getMemoizedStringLen() { + return astStr == null ? 0 : astStr.length(); + } + + private String getMemoizedSubString(int start, int end) { + return (astStr == null || start < 0 || end > astStr.length() || start >= end) ? null : + astStr.subSequence(start, end).toString(); + } + + private void addtoMemoizedString(String string) { + if (astStr == null) { + astStr = new StringBuffer(); + } + astStr.append(string); + } + + @Override + public void setParent(Tree t) { + super.setParent(t); + resetRootInformation(); + } + + @Override + public void addChild(Tree t) { + super.addChild(t); + resetRootInformation(); + } + + @Override + public void addChildren(List kids) { + super.addChildren(kids); + resetRootInformation(); + } + + @Override + public void setChild(int i, Tree t) { + super.setChild(i, t); + resetRootInformation(); + } + + @Override + public void insertChild(int i, Object t) { + super.insertChild(i, t); + resetRootInformation(); + } + + @Override + public Object deleteChild(int i) { + Object ret = super.deleteChild(i); + resetRootInformation(); + return ret; + } + + @Override + public void replaceChildren(int startChildIndex, int stopChildIndex, Object t) { + super.replaceChildren(startChildIndex, stopChildIndex, t); + resetRootInformation(); + } + + @Override + public String toStringTree() { + + // The root might have changed because of tree modifications. + // Compute the new root for this tree and set the astStr. + getRootNodeWithValidASTStr(true); + + // If rootNotModified is false, then startIndx and endIndx will be stale. + if (startIndx >= 0 && endIndx <= rootNode.getMemoizedStringLen()) { + return rootNode.getMemoizedSubString(startIndx, endIndx); + } + return toStringTree(rootNode); + } + + private String toStringTree(ASTNode rootNode) { + this.rootNode = rootNode; + startIndx = rootNode.getMemoizedStringLen(); + // Leaf node + if ( children==null || children.size()==0 ) { + rootNode.addtoMemoizedString(this.toString()); + endIndx = rootNode.getMemoizedStringLen(); + return this.toString(); + } + if ( !isNil() ) { + rootNode.addtoMemoizedString("("); + rootNode.addtoMemoizedString(this.toString()); + rootNode.addtoMemoizedString(" "); + } + for (int i = 0; children!=null && i < children.size(); i++) { + ASTNode t = (ASTNode)children.get(i); + if ( i>0 ) { + rootNode.addtoMemoizedString(" "); + } + t.toStringTree(rootNode); + } + if ( !isNil() ) { + rootNode.addtoMemoizedString(")"); + } + endIndx = rootNode.getMemoizedStringLen(); + return rootNode.getMemoizedSubString(startIndx, endIndx); + } +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java new file mode 100644 index 000000000000..c77198b087cb --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java @@ -0,0 +1,213 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.util.ArrayList; +import org.antlr.runtime.ANTLRStringStream; +import org.antlr.runtime.CharStream; +import org.antlr.runtime.NoViableAltException; +import org.antlr.runtime.RecognitionException; +import org.antlr.runtime.Token; +import org.antlr.runtime.TokenRewriteStream; +import org.antlr.runtime.TokenStream; +import org.antlr.runtime.tree.CommonTree; +import org.antlr.runtime.tree.CommonTreeAdaptor; +import org.antlr.runtime.tree.TreeAdaptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.apache.hadoop.hive.ql.Context; + +/** + * ParseDriver. + * + */ +public class ParseDriver { + + private static final Logger LOG = LoggerFactory.getLogger("hive.ql.parse.ParseDriver"); + + /** + * ANTLRNoCaseStringStream. + * + */ + //This class provides and implementation for a case insensitive token checker + //for the lexical analysis part of antlr. By converting the token stream into + //upper case at the time when lexical rules are checked, this class ensures that the + //lexical rules need to just match the token with upper case letters as opposed to + //combination of upper case and lower case characters. This is purely used for matching lexical + //rules. The actual token text is stored in the same way as the user input without + //actually converting it into an upper case. The token values are generated by the consume() + //function of the super class ANTLRStringStream. The LA() function is the lookahead function + //and is purely used for matching lexical rules. This also means that the grammar will only + //accept capitalized tokens in case it is run from other tools like antlrworks which + //do not have the ANTLRNoCaseStringStream implementation. + public class ANTLRNoCaseStringStream extends ANTLRStringStream { + + public ANTLRNoCaseStringStream(String input) { + super(input); + } + + @Override + public int LA(int i) { + + int returnChar = super.LA(i); + if (returnChar == CharStream.EOF) { + return returnChar; + } else if (returnChar == 0) { + return returnChar; + } + + return Character.toUpperCase((char) returnChar); + } + } + + /** + * HiveLexerX. + * + */ + public class HiveLexerX extends SparkSqlLexer { + + private final ArrayList errors; + + public HiveLexerX(CharStream input) { + super(input); + errors = new ArrayList(); + } + + @Override + public void displayRecognitionError(String[] tokenNames, RecognitionException e) { + errors.add(new ParseError(this, e, tokenNames)); + } + + @Override + public String getErrorMessage(RecognitionException e, String[] tokenNames) { + String msg = null; + + if (e instanceof NoViableAltException) { + // @SuppressWarnings("unused") + // NoViableAltException nvae = (NoViableAltException) e; + // for development, can add + // "decision=<<"+nvae.grammarDecisionDescription+">>" + // and "(decision="+nvae.decisionNumber+") and + // "state "+nvae.stateNumber + msg = "character " + getCharErrorDisplay(e.c) + " not supported here"; + } else { + msg = super.getErrorMessage(e, tokenNames); + } + + return msg; + } + + public ArrayList getErrors() { + return errors; + } + + } + + /** + * Tree adaptor for making antlr return ASTNodes instead of CommonTree nodes + * so that the graph walking algorithms and the rules framework defined in + * ql.lib can be used with the AST Nodes. + */ + public static final TreeAdaptor adaptor = new CommonTreeAdaptor() { + /** + * Creates an ASTNode for the given token. The ASTNode is a wrapper around + * antlr's CommonTree class that implements the Node interface. + * + * @param payload + * The token. + * @return Object (which is actually an ASTNode) for the token. + */ + @Override + public Object create(Token payload) { + return new ASTNode(payload); + } + + @Override + public Object dupNode(Object t) { + + return create(((CommonTree)t).token); + }; + + @Override + public Object errorNode(TokenStream input, Token start, Token stop, RecognitionException e) { + return new ASTErrorNode(input, start, stop, e); + }; + }; + + public ASTNode parse(String command) throws ParseException { + return parse(command, null); + } + + public ASTNode parse(String command, Context ctx) + throws ParseException { + return parse(command, ctx, true); + } + + /** + * Parses a command, optionally assigning the parser's token stream to the + * given context. + * + * @param command + * command to parse + * + * @param ctx + * context with which to associate this parser's token stream, or + * null if either no context is available or the context already has + * an existing stream + * + * @return parsed AST + */ + public ASTNode parse(String command, Context ctx, boolean setTokenRewriteStream) + throws ParseException { + LOG.info("Parsing command: " + command); + + HiveLexerX lexer = new HiveLexerX(new ANTLRNoCaseStringStream(command)); + TokenRewriteStream tokens = new TokenRewriteStream(lexer); + if (ctx != null) { + if ( setTokenRewriteStream) { + ctx.setTokenRewriteStream(tokens); + } + lexer.setHiveConf(ctx.getConf()); + } + SparkSqlParser parser = new SparkSqlParser(tokens); + if (ctx != null) { + parser.setHiveConf(ctx.getConf()); + } + parser.setTreeAdaptor(adaptor); + SparkSqlParser.statement_return r = null; + try { + r = parser.statement(); + } catch (RecognitionException e) { + e.printStackTrace(); + throw new ParseException(parser.errors); + } + + if (lexer.getErrors().size() == 0 && parser.errors.size() == 0) { + LOG.info("Parse Completed"); + } else if (lexer.getErrors().size() != 0) { + throw new ParseException(lexer.getErrors()); + } else { + throw new ParseException(parser.errors); + } + + ASTNode tree = (ASTNode) r.getTree(); + tree.setUnknownTokenBoundaries(); + return tree; + } +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java new file mode 100644 index 000000000000..b47bcfb2914d --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import org.antlr.runtime.BaseRecognizer; +import org.antlr.runtime.RecognitionException; + +/** + * + */ +public class ParseError { + private final BaseRecognizer br; + private final RecognitionException re; + private final String[] tokenNames; + + ParseError(BaseRecognizer br, RecognitionException re, String[] tokenNames) { + this.br = br; + this.re = re; + this.tokenNames = tokenNames; + } + + BaseRecognizer getBaseRecognizer() { + return br; + } + + RecognitionException getRecognitionException() { + return re; + } + + String[] getTokenNames() { + return tokenNames; + } + + String getMessage() { + return br.getErrorHeader(re) + " " + br.getErrorMessage(re, tokenNames); + } + +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java new file mode 100644 index 000000000000..fff891ced555 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.util.ArrayList; + +/** + * ParseException. + * + */ +public class ParseException extends Exception { + + private static final long serialVersionUID = 1L; + ArrayList errors; + + public ParseException(ArrayList errors) { + super(); + this.errors = errors; + } + + @Override + public String getMessage() { + + StringBuilder sb = new StringBuilder(); + for (ParseError err : errors) { + if (sb.length() > 0) { + sb.append('\n'); + } + sb.append(err.getMessage()); + } + + return sb.toString(); + } + +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java new file mode 100644 index 000000000000..a5c2998f86cc --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java @@ -0,0 +1,96 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; + + +/** + * Library of utility functions used in the parse code. + * + */ +public final class ParseUtils { + /** + * Performs a descent of the leftmost branch of a tree, stopping when either a + * node with a non-null token is found or the leaf level is encountered. + * + * @param tree + * candidate node from which to start searching + * + * @return node at which descent stopped + */ + public static ASTNode findRootNonNullToken(ASTNode tree) { + while ((tree.getToken() == null) && (tree.getChildCount() > 0)) { + tree = (org.apache.spark.sql.parser.ASTNode) tree.getChild(0); + } + return tree; + } + + private ParseUtils() { + // prevent instantiation + } + + public static VarcharTypeInfo getVarcharTypeInfo(ASTNode node) + throws SemanticException { + if (node.getChildCount() != 1) { + throw new SemanticException("Bad params for type varchar"); + } + + String lengthStr = node.getChild(0).getText(); + return TypeInfoFactory.getVarcharTypeInfo(Integer.valueOf(lengthStr)); + } + + public static CharTypeInfo getCharTypeInfo(ASTNode node) + throws SemanticException { + if (node.getChildCount() != 1) { + throw new SemanticException("Bad params for type char"); + } + + String lengthStr = node.getChild(0).getText(); + return TypeInfoFactory.getCharTypeInfo(Integer.valueOf(lengthStr)); + } + + public static DecimalTypeInfo getDecimalTypeTypeInfo(ASTNode node) + throws SemanticException { + if (node.getChildCount() > 2) { + throw new SemanticException("Bad params for type decimal"); + } + + int precision = HiveDecimal.USER_DEFAULT_PRECISION; + int scale = HiveDecimal.USER_DEFAULT_SCALE; + + if (node.getChildCount() >= 1) { + String precStr = node.getChild(0).getText(); + precision = Integer.valueOf(precStr); + } + + if (node.getChildCount() == 2) { + String scaleStr = node.getChild(1).getText(); + scale = Integer.valueOf(scaleStr); + } + + return TypeInfoFactory.getDecimalTypeInfo(precision, scale); + } + +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java new file mode 100644 index 000000000000..4b2015e0df84 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java @@ -0,0 +1,406 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.antlr.runtime.tree.Tree; +import org.apache.commons.lang.StringUtils; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.FieldSchema; +import org.apache.hadoop.hive.ql.ErrorMsg; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.serde.serdeConstants; +import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; + +/** + * SemanticAnalyzer. + * + */ +public abstract class SemanticAnalyzer { + public static String charSetString(String charSetName, String charSetString) + throws SemanticException { + try { + // The character set name starts with a _, so strip that + charSetName = charSetName.substring(1); + if (charSetString.charAt(0) == '\'') { + return new String(unescapeSQLString(charSetString).getBytes(), + charSetName); + } else // hex input is also supported + { + assert charSetString.charAt(0) == '0'; + assert charSetString.charAt(1) == 'x'; + charSetString = charSetString.substring(2); + + byte[] bArray = new byte[charSetString.length() / 2]; + int j = 0; + for (int i = 0; i < charSetString.length(); i += 2) { + int val = Character.digit(charSetString.charAt(i), 16) * 16 + + Character.digit(charSetString.charAt(i + 1), 16); + if (val > 127) { + val = val - 256; + } + bArray[j++] = (byte)val; + } + + String res = new String(bArray, charSetName); + return res; + } + } catch (UnsupportedEncodingException e) { + throw new SemanticException(e); + } + } + + /** + * Remove the encapsulating "`" pair from the identifier. We allow users to + * use "`" to escape identifier for table names, column names and aliases, in + * case that coincide with Hive language keywords. + */ + public static String unescapeIdentifier(String val) { + if (val == null) { + return null; + } + if (val.charAt(0) == '`' && val.charAt(val.length() - 1) == '`') { + val = val.substring(1, val.length() - 1); + } + return val; + } + + /** + * Converts parsed key/value properties pairs into a map. + * + * @param prop ASTNode parent of the key/value pairs + * + * @param mapProp property map which receives the mappings + */ + public static void readProps( + ASTNode prop, Map mapProp) { + + for (int propChild = 0; propChild < prop.getChildCount(); propChild++) { + String key = unescapeSQLString(prop.getChild(propChild).getChild(0) + .getText()); + String value = null; + if (prop.getChild(propChild).getChild(1) != null) { + value = unescapeSQLString(prop.getChild(propChild).getChild(1).getText()); + } + mapProp.put(key, value); + } + } + + private static final int[] multiplier = new int[] {1000, 100, 10, 1}; + + @SuppressWarnings("nls") + public static String unescapeSQLString(String b) { + Character enclosure = null; + + // Some of the strings can be passed in as unicode. For example, the + // delimiter can be passed in as \002 - So, we first check if the + // string is a unicode number, else go back to the old behavior + StringBuilder sb = new StringBuilder(b.length()); + for (int i = 0; i < b.length(); i++) { + + char currentChar = b.charAt(i); + if (enclosure == null) { + if (currentChar == '\'' || b.charAt(i) == '\"') { + enclosure = currentChar; + } + // ignore all other chars outside the enclosure + continue; + } + + if (enclosure.equals(currentChar)) { + enclosure = null; + continue; + } + + if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { + int code = 0; + int base = i + 2; + for (int j = 0; j < 4; j++) { + int digit = Character.digit(b.charAt(j + base), 16); + code += digit * multiplier[j]; + } + sb.append((char)code); + i += 5; + continue; + } + + if (currentChar == '\\' && (i + 4 < b.length())) { + char i1 = b.charAt(i + 1); + char i2 = b.charAt(i + 2); + char i3 = b.charAt(i + 3); + if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') + && (i3 >= '0' && i3 <= '7')) { + byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); + byte[] bValArr = new byte[1]; + bValArr[0] = bVal; + String tmp = new String(bValArr); + sb.append(tmp); + i += 3; + continue; + } + } + + if (currentChar == '\\' && (i + 2 < b.length())) { + char n = b.charAt(i + 1); + switch (n) { + case '0': + sb.append("\0"); + break; + case '\'': + sb.append("'"); + break; + case '"': + sb.append("\""); + break; + case 'b': + sb.append("\b"); + break; + case 'n': + sb.append("\n"); + break; + case 'r': + sb.append("\r"); + break; + case 't': + sb.append("\t"); + break; + case 'Z': + sb.append("\u001A"); + break; + case '\\': + sb.append("\\"); + break; + // The following 2 lines are exactly what MySQL does TODO: why do we do this? + case '%': + sb.append("\\%"); + break; + case '_': + sb.append("\\_"); + break; + default: + sb.append(n); + } + i++; + } else { + sb.append(currentChar); + } + } + return sb.toString(); + } + + /** + * Get the list of FieldSchema out of the ASTNode. + */ + public static List getColumns(ASTNode ast, boolean lowerCase) throws SemanticException { + List colList = new ArrayList(); + int numCh = ast.getChildCount(); + for (int i = 0; i < numCh; i++) { + FieldSchema col = new FieldSchema(); + ASTNode child = (ASTNode) ast.getChild(i); + Tree grandChild = child.getChild(0); + if(grandChild != null) { + String name = grandChild.getText(); + if(lowerCase) { + name = name.toLowerCase(); + } + // child 0 is the name of the column + col.setName(unescapeIdentifier(name)); + // child 1 is the type of the column + ASTNode typeChild = (ASTNode) (child.getChild(1)); + col.setType(getTypeStringFromAST(typeChild)); + + // child 2 is the optional comment of the column + if (child.getChildCount() == 3) { + col.setComment(unescapeSQLString(child.getChild(2).getText())); + } + } + colList.add(col); + } + return colList; + } + + protected static String getTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + switch (typeNode.getType()) { + case SparkSqlParser.TOK_LIST: + return serdeConstants.LIST_TYPE_NAME + "<" + + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + ">"; + case SparkSqlParser.TOK_MAP: + return serdeConstants.MAP_TYPE_NAME + "<" + + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + "," + + getTypeStringFromAST((ASTNode) typeNode.getChild(1)) + ">"; + case SparkSqlParser.TOK_STRUCT: + return getStructTypeStringFromAST(typeNode); + case SparkSqlParser.TOK_UNIONTYPE: + return getUnionTypeStringFromAST(typeNode); + default: + return getTypeName(typeNode); + } + } + + private static String getStructTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + String typeStr = serdeConstants.STRUCT_TYPE_NAME + "<"; + typeNode = (ASTNode) typeNode.getChild(0); + int children = typeNode.getChildCount(); + if (children <= 0) { + throw new SemanticException("empty struct not allowed."); + } + StringBuilder buffer = new StringBuilder(typeStr); + for (int i = 0; i < children; i++) { + ASTNode child = (ASTNode) typeNode.getChild(i); + buffer.append(unescapeIdentifier(child.getChild(0).getText())).append(":"); + buffer.append(getTypeStringFromAST((ASTNode) child.getChild(1))); + if (i < children - 1) { + buffer.append(","); + } + } + + buffer.append(">"); + return buffer.toString(); + } + + private static String getUnionTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + String typeStr = serdeConstants.UNION_TYPE_NAME + "<"; + typeNode = (ASTNode) typeNode.getChild(0); + int children = typeNode.getChildCount(); + if (children <= 0) { + throw new SemanticException("empty union not allowed."); + } + StringBuilder buffer = new StringBuilder(typeStr); + for (int i = 0; i < children; i++) { + buffer.append(getTypeStringFromAST((ASTNode) typeNode.getChild(i))); + if (i < children - 1) { + buffer.append(","); + } + } + buffer.append(">"); + typeStr = buffer.toString(); + return typeStr; + } + + public static String getAstNodeText(ASTNode tree) { + return tree.getChildCount() == 0?tree.getText() : + getAstNodeText((ASTNode)tree.getChild(tree.getChildCount() - 1)); + } + + public static String generateErrorMessage(ASTNode ast, String message) { + StringBuilder sb = new StringBuilder(); + if (ast == null) { + sb.append(message).append(". Cannot tell the position of null AST."); + return sb.toString(); + } + sb.append(ast.getLine()); + sb.append(":"); + sb.append(ast.getCharPositionInLine()); + sb.append(" "); + sb.append(message); + sb.append(". Error encountered near token '"); + sb.append(getAstNodeText(ast)); + sb.append("'"); + return sb.toString(); + } + + private static final Map TokenToTypeName = new HashMap(); + + static { + TokenToTypeName.put(SparkSqlParser.TOK_BOOLEAN, serdeConstants.BOOLEAN_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_TINYINT, serdeConstants.TINYINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_SMALLINT, serdeConstants.SMALLINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INT, serdeConstants.INT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_BIGINT, serdeConstants.BIGINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_FLOAT, serdeConstants.FLOAT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DOUBLE, serdeConstants.DOUBLE_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_STRING, serdeConstants.STRING_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_CHAR, serdeConstants.CHAR_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_VARCHAR, serdeConstants.VARCHAR_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_BINARY, serdeConstants.BINARY_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DATE, serdeConstants.DATE_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DATETIME, serdeConstants.DATETIME_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_TIMESTAMP, serdeConstants.TIMESTAMP_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_YEAR_MONTH, serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_DAY_TIME, serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DECIMAL, serdeConstants.DECIMAL_TYPE_NAME); + } + + public static String getTypeName(ASTNode node) throws SemanticException { + int token = node.getType(); + String typeName; + + // datetime type isn't currently supported + if (token == SparkSqlParser.TOK_DATETIME) { + throw new SemanticException(ErrorMsg.UNSUPPORTED_TYPE.getMsg()); + } + + switch (token) { + case SparkSqlParser.TOK_CHAR: + CharTypeInfo charTypeInfo = ParseUtils.getCharTypeInfo(node); + typeName = charTypeInfo.getQualifiedName(); + break; + case SparkSqlParser.TOK_VARCHAR: + VarcharTypeInfo varcharTypeInfo = ParseUtils.getVarcharTypeInfo(node); + typeName = varcharTypeInfo.getQualifiedName(); + break; + case SparkSqlParser.TOK_DECIMAL: + DecimalTypeInfo decTypeInfo = ParseUtils.getDecimalTypeTypeInfo(node); + typeName = decTypeInfo.getQualifiedName(); + break; + default: + typeName = TokenToTypeName.get(token); + } + return typeName; + } + + public static String relativeToAbsolutePath(HiveConf conf, String location) throws SemanticException { + boolean testMode = conf.getBoolVar(HiveConf.ConfVars.HIVETESTMODE); + if (testMode) { + URI uri = new Path(location).toUri(); + String scheme = uri.getScheme(); + String authority = uri.getAuthority(); + String path = uri.getPath(); + if (!path.startsWith("/")) { + path = (new Path(System.getProperty("test.tmp.dir"), + path)).toUri().getPath(); + } + if (StringUtils.isEmpty(scheme)) { + scheme = "pfile"; + } + try { + uri = new URI(scheme, authority, path, null, null); + } catch (URISyntaxException e) { + throw new SemanticException(ErrorMsg.INVALID_PATH.getMsg(), e); + } + return uri.toString(); + } else { + //no-op for non-test mode for now + return location; + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 384ea211df84..5d00e7367026 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -380,7 +380,7 @@ class HiveContext private[hive]( def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDir) { + val size = if (fileStatus.isDirectory) { fs.listStatus(path) .map { status => if (!status.getPath().getName().startsWith(stagingDir)) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0e89928cb636..b1d841d1b554 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -27,28 +27,28 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.lib.Node -import org.apache.hadoop.hive.ql.parse._ +import org.apache.hadoop.hive.ql.parse.SemanticException import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe - import org.apache.spark.Logging -import org.apache.spark.sql.{AnalysisException, catalyst} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} +import org.apache.spark.sql.parser._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, catalyst} import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler @@ -227,7 +227,7 @@ private[hive] object HiveQl extends Logging { */ def withChildren(newChildren: Seq[ASTNode]): ASTNode = { (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - n.addChildren(newChildren.asJava) + newChildren.foreach(n.addChild(_)) n } @@ -273,7 +273,8 @@ private[hive] object HiveQl extends Logging { private def createContext(): Context = new Context(hiveConf) private def getAst(sql: String, context: Context) = - ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, context)) + ParseUtils.findRootNonNullToken( + (new ParseDriver).parse(sql, context)) /** * Returns the HiveConf @@ -312,7 +313,7 @@ private[hive] object HiveQl extends Logging { context.clear() plan } catch { - case pe: org.apache.hadoop.hive.ql.parse.ParseException => + case pe: ParseException => pe.getMessage match { case errorRegEx(line, start, message) => throw new AnalysisException(message, Some(line.toInt), Some(start.toInt)) @@ -337,7 +338,8 @@ private[hive] object HiveQl extends Logging { val tree = try { ParseUtils.findRootNonNullToken( - (new ParseDriver).parse(ddl, null /* no context required for parsing alone */)) + (new ParseDriver) + .parse(ddl, null /* no context required for parsing alone */)) } catch { case pe: org.apache.hadoop.hive.ql.parse.ParseException => throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe) @@ -598,12 +600,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C NativePlaceholder } else { tableType match { - case Token("TOK_TABTYPE", nameParts) if nameParts.size == 1 => { - nameParts.head match { + case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => { + nameParts match { case Token(".", dbName :: tableName :: Nil) => // It is describing a table with the format like "describe db.table". // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. - val tableIdent = extractTableIdent(nameParts.head) + val tableIdent = extractTableIdent(nameParts) DescribeCommand( UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) case Token(".", dbName :: tableName :: colName :: Nil) => @@ -662,7 +664,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C NativePlaceholder } else { val schema = maybeColumns.map { cols => - BaseSemanticAnalyzer.getColumns(cols, true).asScala.map { field => + SemanticAnalyzer.getColumns(cols, true).asScala.map { field => // We can't specify column types when create view, so fill it with null first, and // update it after the schema has been resolved later. HiveColumn(field.getName, null, field.getComment) @@ -678,7 +680,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C maybeComment.foreach { case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + val comment = SemanticAnalyzer.unescapeSQLString(child.getText) if (comment ne null) { properties += ("comment" -> comment) } @@ -750,7 +752,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C children.collect { case list @ Token("TOK_TABCOLLIST", _) => - val cols = BaseSemanticAnalyzer.getColumns(list, true) + val cols = SemanticAnalyzer.getColumns(list, true) if (cols != null) { tableDesc = tableDesc.copy( schema = cols.asScala.map { field => @@ -758,11 +760,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C }) } case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + val comment = SemanticAnalyzer.unescapeSQLString(child.getText) // TODO support the sql text tableDesc = tableDesc.copy(viewText = Option(comment)) case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => - val cols = BaseSemanticAnalyzer.getColumns(list(0), false) + val cols = SemanticAnalyzer.getColumns(list(0), false) if (cols != null) { tableDesc = tableDesc.copy( partitionColumns = cols.asScala.map { field => @@ -773,21 +775,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val serdeParams = new java.util.HashMap[String, String]() child match { case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => - val fieldDelim = BaseSemanticAnalyzer.unescapeSQLString (rowChild1.getText()) + val fieldDelim = SemanticAnalyzer.unescapeSQLString (rowChild1.getText()) serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) if (rowChild2.length > 1) { - val fieldEscape = BaseSemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) + val fieldEscape = SemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) } case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => - val collItemDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val collItemDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => - val mapKeyDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val mapKeyDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => - val lineDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val lineDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) if (!(lineDelim == "\n") && !(lineDelim == "10")) { throw new AnalysisException( SemanticAnalyzer.generateErrorMessage( @@ -796,22 +798,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => - val nullFormat = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val nullFormat = SemanticAnalyzer.unescapeSQLString(rowChild.getText) // TODO support the nullFormat case _ => assert(false) } tableDesc = tableDesc.copy( serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) case Token("TOK_TABLELOCATION", child :: Nil) => - var location = BaseSemanticAnalyzer.unescapeSQLString(child.getText) - location = EximUtil.relativeToAbsolutePath(hiveConf, location) + var location = SemanticAnalyzer.unescapeSQLString(child.getText) + location = SemanticAnalyzer.relativeToAbsolutePath(hiveConf, location) tableDesc = tableDesc.copy(location = Option(location)) case Token("TOK_TABLESERIALIZER", child :: Nil) => tableDesc = tableDesc.copy( - serde = Option(BaseSemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) + serde = Option(SemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) if (child.getChildCount == 2) { val serdeParams = new java.util.HashMap[String, String]() - BaseSemanticAnalyzer.readProps( + SemanticAnalyzer.readProps( (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) tableDesc = tableDesc.copy( serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) @@ -891,9 +893,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case list @ Token("TOK_TABLEFILEFORMAT", children) => tableDesc = tableDesc.copy( inputFormat = - Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), + Option(SemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), outputFormat = - Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) + Option(SemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) case Token("TOK_STORAGEHANDLER", _) => throw new AnalysisException(ErrorMsg.CREATE_NON_NATIVE_AS.getMsg()) case _ => // Unsupport features @@ -909,24 +911,20 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder case Token("TOK_QUERY", queryArgs) - if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => + if Seq("TOK_CTE", "TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => val (fromClause: Option[ASTNode], insertClauses, cteRelations) = queryArgs match { - case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => - // check if has CTE - insertClauses.last match { - case Token("TOK_CTE", cteClauses) => - val cteRelations = cteClauses.map(node => { - val relation = nodeToRelation(node, context).asInstanceOf[Subquery] - (relation.alias, relation) - }).toMap - (Some(args.head), insertClauses.init, Some(cteRelations)) - - case _ => (Some(args.head), insertClauses, None) + case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts => + val cteRelations = ctes.map { node => + val relation = nodeToRelation(node, context).asInstanceOf[Subquery] + relation.alias -> relation } - - case Token("TOK_INSERT", _) :: Nil => (None, queryArgs, None) + (Some(from.head), inserts, Some(cteRelations.toMap)) + case Token("TOK_FROM", from) :: inserts => + (Some(from.head), inserts, None) + case Token("TOK_INSERT", _) :: Nil => + (None, queryArgs, None) } // Return one query for each insert clause. @@ -1025,20 +1023,20 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C (rowFormat, None, Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) + (Nil, Some(SemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => val serdeProps = propsClause.map { case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (BaseSemanticAnalyzer.unescapeSQLString(name), - BaseSemanticAnalyzer.unescapeSQLString(value)) + (SemanticAnalyzer.unescapeSQLString(name), + SemanticAnalyzer.unescapeSQLString(value)) } // SPARK-10310: Special cases LazySimpleSerDe // TODO Fully supports user-defined record reader/writer classes - val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) + val unescapedSerDeClass = SemanticAnalyzer.unescapeSQLString(serdeClass) val useDefaultRecordReaderWriter = unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) @@ -1055,7 +1053,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = matchSerDe(outputSerdeClause) - val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) + val unescapedScript = SemanticAnalyzer.unescapeSQLString(script) // TODO Adds support for user-defined record reader/writer classes val recordReaderClass = if (useDefaultRecordReader) { @@ -1361,6 +1359,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case "TOK_LEFTOUTERJOIN" => LeftOuter case "TOK_FULLOUTERJOIN" => FullOuter case "TOK_LEFTSEMIJOIN" => LeftSemi + case "TOK_ANTIJOIN" => throw new NotImplementedError("Anti join not supported") } Join(nodeToRelation(relation1, context), nodeToRelation(relation2, context), @@ -1475,11 +1474,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val numericAstTypes = Seq( - HiveParser.Number, - HiveParser.TinyintLiteral, - HiveParser.SmallintLiteral, - HiveParser.BigintLiteral, - HiveParser.DecimalLiteral) + SparkSqlParser.Number, + SparkSqlParser.TinyintLiteral, + SparkSqlParser.SmallintLiteral, + SparkSqlParser.BigintLiteral, + SparkSqlParser.DecimalLiteral) /* Case insensitive matches */ val COUNT = "(?i)COUNT".r @@ -1649,7 +1648,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token(TRUE(), Nil) => Literal.create(true, BooleanType) case Token(FALSE(), Nil) => Literal.create(false, BooleanType) case Token("TOK_STRINGLITERALSEQUENCE", strings) => - Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) + Literal(strings.map(s => SemanticAnalyzer.unescapeSQLString(s.getText)).mkString) // This code is adapted from // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 @@ -1684,37 +1683,37 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C v } - case ast: ASTNode if ast.getType == HiveParser.StringLiteral => - Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText)) + case ast: ASTNode if ast.getType == SparkSqlParser.StringLiteral => + Literal(SemanticAnalyzer.unescapeSQLString(ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_DATELITERAL => Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) - case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => - Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_CHARSETLITERAL => + Literal(SemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => Literal(CalendarInterval.fromYearMonthString(ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => Literal(CalendarInterval.fromDayTimeString(ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL => Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL => Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL => Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL => Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL => Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL => Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) case a: ASTNode => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 598ccdeee4ad..d3da22aa0ae5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -31,9 +31,7 @@ import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.{Driver, metadata} -import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.util.VersionInfo import org.apache.spark.{SparkConf, SparkException, Logging} import org.apache.spark.sql.catalyst.expressions.Expression @@ -65,74 +63,6 @@ private[hive] class ClientWrapper( extends ClientInterface with Logging { - overrideHadoopShims() - - // !! HACK ALERT !! - // - // Internally, Hive `ShimLoader` tries to load different versions of Hadoop shims by checking - // major version number gathered from Hadoop jar files: - // - // - For major version number 1, load `Hadoop20SShims`, where "20S" stands for Hadoop 0.20 with - // security. - // - For major version number 2, load `Hadoop23Shims`, where "23" stands for Hadoop 0.23. - // - // However, APIs in Hadoop 2.0.x and 2.1.x versions were in flux due to historical reasons. It - // turns out that Hadoop 2.0.x versions should also be used together with `Hadoop20SShims`, but - // `Hadoop23Shims` is chosen because the major version number here is 2. - // - // To fix this issue, we try to inspect Hadoop version via `org.apache.hadoop.utils.VersionInfo` - // and load `Hadoop20SShims` for Hadoop 1.x and 2.0.x versions. If Hadoop version information is - // not available, we decide whether to override the shims or not by checking for existence of a - // probe method which doesn't exist in Hadoop 1.x or 2.0.x versions. - private def overrideHadoopShims(): Unit = { - val hadoopVersion = VersionInfo.getVersion - val VersionPattern = """(\d+)\.(\d+).*""".r - - hadoopVersion match { - case null => - logError("Failed to inspect Hadoop version") - - // Using "Path.getPathWithoutSchemeAndAuthority" as the probe method. - val probeMethod = "getPathWithoutSchemeAndAuthority" - if (!classOf[Path].getDeclaredMethods.exists(_.getName == probeMethod)) { - logInfo( - s"Method ${classOf[Path].getCanonicalName}.$probeMethod not found, " + - s"we are probably using Hadoop 1.x or 2.0.x") - loadHadoop20SShims() - } - - case VersionPattern(majorVersion, minorVersion) => - logInfo(s"Inspected Hadoop version: $hadoopVersion") - - // Loads Hadoop20SShims for 1.x and 2.0.x versions - val (major, minor) = (majorVersion.toInt, minorVersion.toInt) - if (major < 2 || (major == 2 && minor == 0)) { - loadHadoop20SShims() - } - } - - // Logs the actual loaded Hadoop shims class - val loadedShimsClassName = ShimLoader.getHadoopShims.getClass.getCanonicalName - logInfo(s"Loaded $loadedShimsClassName for Hadoop version $hadoopVersion") - } - - private def loadHadoop20SShims(): Unit = { - val hadoop20SShimsClassName = "org.apache.hadoop.hive.shims.Hadoop20SShims" - logInfo(s"Loading Hadoop shims $hadoop20SShimsClassName") - - try { - val shimsField = classOf[ShimLoader].getDeclaredField("hadoopShims") - // scalastyle:off classforname - val shimsClass = Class.forName(hadoop20SShimsClassName) - // scalastyle:on classforname - val shims = classOf[HadoopShims].cast(shimsClass.newInstance()) - shimsField.setAccessible(true) - shimsField.set(null, shims) - } catch { case cause: Throwable => - throw new RuntimeException(s"Failed to load $hadoop20SShimsClassName", cause) - } - } - // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. private val outputBuffer = new CircularBuffer() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 8141136de531..1588728bdbaa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -132,11 +132,17 @@ case class HiveTableScan( } } - protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) - } else { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + protected override def doExecute(): RDD[InternalRow] = { + val rdd = if (!relation.hiveQlTable.isPartitioned) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } else { + hadoopReader.makeRDDForPartitionedTable( + prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + } + rdd.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(schema) + iter.map(proj) + } } override def output: Seq[Attribute] = attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f936cf565b2b..44dc68e6ba47 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -28,18 +28,17 @@ import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, Attribute} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.DataType -import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.util.SerializableJobConf +import org.apache.spark.{SparkException, TaskContext} private[hive] case class InsertIntoHiveTable( @@ -101,15 +100,17 @@ case class InsertIntoHiveTable( writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) + val proj = FromUnsafeProjection(child.schema) iterator.foreach { row => var i = 0 + val safeRow = proj(row) while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(safeRow.get(i, dataTypes(i))) i += 1 } writerContainer - .getLocalFileWriter(row, table.schema) + .getLocalFileWriter(safeRow, table.schema) .write(serializer.serialize(outputData, standardOI)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index a61e162f48f1..6ccd4178190c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -213,7 +213,8 @@ case class ScriptTransformation( child.execute().mapPartitions { iter => if (iter.hasNext) { - processIterator(iter) + val proj = UnsafeProjection.create(schema) + processIterator(iter).map(proj) } else { // If the input iterator has no rows then do not launch the external script. Iterator.empty diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 93c016b6c6c7..777e7857d2db 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -27,9 +27,10 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.common.FileUtils import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ -import org.apache.hadoop.hive.common.FileUtils +import org.apache.hadoop.mapreduce.TaskType import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} @@ -46,9 +47,7 @@ import org.apache.spark.util.SerializableJobConf private[hive] class SparkHiveWriterContainer( jobConf: JobConf, fileSinkConf: FileSinkDesc) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { + extends Logging with Serializable { private val now = new Date() private val tableDesc: TableDesc = fileSinkConf.getTableInfo @@ -68,8 +67,8 @@ private[hive] class SparkHiveWriterContainer( @transient private var writer: FileSinkOperator.RecordWriter = null @transient protected lazy val committer = conf.value.getOutputCommitter - @transient protected lazy val jobContext = newJobContext(conf.value, jID.value) - @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) + @transient protected lazy val jobContext = new JobContextImpl(conf.value, jID.value) + @transient private lazy val taskContext = new TaskAttemptContextImpl(conf.value, taID.value) @transient private lazy val outputFormat = conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]] @@ -131,7 +130,7 @@ private[hive] class SparkHiveWriterContainer( jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) } private def setConfParams() { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 0f9a1a6ef3b2..b91a14bdbcc4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -95,7 +95,7 @@ private[orc] object OrcFileOperator extends Logging { val fs = origPath.getFileSystem(conf) val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) - .filterNot(_.isDir) + .filterNot(_.isDirectory) .map(_.getPath) .filterNot(_.getName.startsWith("_")) .filterNot(_.getName.startsWith(".")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 1136670b7a0e..84ef12a68e1b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -33,8 +33,6 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -67,7 +65,7 @@ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { + extends OutputWriter with HiveInspectors { private val serializer = { val table = new Properties() @@ -77,7 +75,7 @@ private[orc] class OrcOutputWriter( }.mkString(":")) val serde = new OrcSerde - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration serde.initialize(configuration, table) serde } @@ -99,9 +97,9 @@ private[orc] class OrcOutputWriter( private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { recordWriterInstantiated = true - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val conf = context.getConfiguration val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val partition = taskAttemptId.getTaskID.getId val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" @@ -208,7 +206,7 @@ private[sql] class OrcRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = { - SparkHadoopUtil.get.getConfigurationFromJobContext(job) match { + job.getConfiguration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) case conf => @@ -289,8 +287,8 @@ private[orc] case class OrcTableScan( } def execute(): RDD[InternalRow] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 97792549bb7a..013fbab0a812 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -410,7 +410,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { try { // HACK: Hive is too noisy by default. org.apache.log4j.LogManager.getCurrentLoggers.asScala.foreach { log => - log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + val logger = log.asInstanceOf[org.apache.log4j.Logger] + if (!logger.getName.contains("org.apache.spark")) { + logger.setLevel(org.apache.log4j.Level.WARN) + } } cacheManager.clearCache() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 01960fd2901b..e10d21d5e368 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -25,7 +25,6 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} @@ -53,9 +52,9 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId val name = FileOutputFormat.getOutputName(context) new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 665e87e3e335..efbf9988ddc1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -27,7 +27,6 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ -import org.apache.spark.sql.execution.ConvertToUnsafe import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -689,36 +688,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) } } - - test("HadoopFsRelation produces UnsafeRow") { - withTempTable("test_unsafe") { - withTempPath { dir => - val path = dir.getCanonicalPath - sqlContext.range(3).write.format(dataSourceName).save(path) - sqlContext.read - .format(dataSourceName) - .option("dataSchema", new StructType().add("id", LongType, nullable = false).json) - .load(path) - .registerTempTable("test_unsafe") - - val df = sqlContext.sql( - """SELECT COUNT(*) - |FROM test_unsafe a JOIN test_unsafe b - |WHERE a.id = b.id - """.stripMargin) - - val plan = df.queryExecution.executedPlan - - assert( - plan.collect { case plan: ConvertToUnsafe => plan }.isEmpty, - s"""Query plan shouldn't have ${classOf[ConvertToUnsafe].getSimpleName} node(s): - |$plan - """.stripMargin) - - checkAnswer(df, Row(3)) - } - } - } } // This class is used to test SPARK-8578. We should not use any custom output committer when diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 9418beec0d74..15ad2e27d372 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -224,7 +224,8 @@ private[streaming] class FileBasedWriteAheadLog( val logDirectoryPath = new Path(logDirectory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { + if (fileSystem.exists(logDirectoryPath) && + fileSystem.getFileStatus(logDirectoryPath).isDirectory) { val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) pastLogs.clear() pastLogs ++= logFileInfo diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala index 1185f30265f6..1f5c1d4369b5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -19,10 +19,7 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import scala.util.Try - import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FSDataOutputStream import org.apache.spark.util.Utils @@ -34,11 +31,6 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: private lazy val stream = HdfsUtils.getOutputStream(path, hadoopConf) - private lazy val hadoopFlushMethod = { - // Use reflection to get the right flush operation - val cls = classOf[FSDataOutputStream] - Try(cls.getMethod("hflush")).orElse(Try(cls.getMethod("sync"))).toOption - } private var nextOffset = stream.getPos() private var closed = false @@ -62,7 +54,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: } private def flush() { - hadoopFlushMethod.foreach { _.invoke(stream) } + stream.hflush() // Useful for local file system where hflush/sync does not work (HADOOP-7844) stream.getWrappedStream.flush() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index beaae34535fd..a670c7d63819 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -705,7 +705,8 @@ object WriteAheadLogSuite { val logDirectoryPath = new Path(directory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { + if (fileSystem.exists(logDirectoryPath) && + fileSystem.getFileStatus(logDirectoryPath).isDirectory) { fileSystem.listStatus(logDirectoryPath).map { _.getPath() }.sortBy { _.getName().split("-")(1).toLong }.map {